This is an automated email from the ASF dual-hosted git repository.

github-merge-queue[bot] pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 3c5361881a fix: Incorrect behavior for `FILTER` on NULLs (#22068)
3c5361881a is described below

commit 3c5361881a52f0fc673983cc0eee213723e12890
Author: Neil Conway <[email protected]>
AuthorDate: Tue May 12 00:52:18 2026 -0400

    fix: Incorrect behavior for `FILTER` on NULLs (#22068)
    
    ## Which issue does this PR close?
    
    - Closes #22067.
    
    ## Rationale for this change
    
    In the grouping code, `accumulate_multiple` and `accumulate_indices`
    take a `BooleanArray` parameter that holds the result of the aggregate's
    `FILTER` clause (if any). Both functions only consider the value bits of
    that array and ignore its NULL bitmap. As a result, a `NULL` filter
    result whose underlying value bit is set is effectively treated as true,
    when it should always be treated as false.
    
    ## What changes are included in this PR?
    
    * Fix NULL handling in `accumulate_multiple` and `accumulate_indices`
    * Refactor `accumulate_multiple` to be more readable and make use of
    `NullBuffer::union_many`
    * Introduce a new helper, `filter_to_validity`
    * Optimize `filter_to_nulls` to use `filter_to_validity` and avoid
    constructing an unnecessary intermediate `NullBuffer`
    * Add unit tests for NULL handling in `accumulate_multiple` and
    `accumulate_indices`
    * Add SLT tests with SQL repros for both code paths
    
    ## Are these changes tested?
    
    Yes, with new tests added.
    
    ## Are there any user-facing changes?
    
    This changes query behavior for the affected (buggy) queries.
    
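    This PR also changes the signature of `filter_to_nulls`, which is
    technically a public API: it now returns `NullBuffer` instead of
    `Option<NullBuffer>`. Callers that used
    `opt_filter.and_then(filter_to_nulls)` should switch to
    `opt_filter.map(filter_to_nulls)`.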
    
    ---------
    
    Co-authored-by: Matt Butrovich <[email protected]>
    Co-authored-by: Daniël Heres <[email protected]>
---
 .../src/aggregate/groups_accumulator/accumulate.rs | 114 ++++++++++++++++-----
 .../src/aggregate/groups_accumulator/nulls.rs      |  27 +++--
 datafusion/functions-aggregate/src/array_agg.rs    |   2 +-
 datafusion/sqllogictest/test_files/aggregate.slt   |  20 ++++
 4 files changed, 125 insertions(+), 38 deletions(-)

diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs
index 25f52df611..09e1df4eae 100644
--- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs
+++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs
@@ -23,6 +23,7 @@ use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
 use arrow::buffer::NullBuffer;
 use arrow::datatypes::ArrowPrimitiveType;
 
+use crate::aggregate::groups_accumulator::nulls::filter_to_validity;
 use datafusion_expr_common::groups_accumulator::EmitTo;
 
 /// If the input has nulls, then the accumulator must potentially
@@ -471,7 +472,7 @@ pub fn accumulate<T, F>(
 ///
 /// This method assumes that for any input record index, if any of the value 
column
 /// is null, or it's filtered out by `opt_filter`, then the record would be 
ignored.
-/// (won't be accumulated by `value_fn`)
+/// (Won't be accumulated by `value_fn`)
 ///
 /// # Arguments
 ///
@@ -491,35 +492,28 @@ pub fn accumulate_multiple<T, F>(
     T: ArrowPrimitiveType + Send,
     F: FnMut(usize, usize, &[&PrimitiveArray<T>]) + Send,
 {
-    // Calculate `valid_indices` to accumulate, non-valid indices are ignored.
-    // `valid_indices` is a bit mask corresponding to the `group_indices`. An 
index
-    // is considered valid if:
-    // 1. All columns are non-null at this index.
-    // 2. Not filtered out by `opt_filter`
-
-    // Take AND from all null buffers of `value_columns`.
-    let combined_nulls = value_columns
-        .iter()
-        .map(|arr| arr.logical_nulls())
-        .fold(None, |acc, nulls| {
-            NullBuffer::union(acc.as_ref(), nulls.as_ref())
-        });
-
-    // Take AND from previous combined nulls and `opt_filter`.
-    let valid_indices = match (combined_nulls, opt_filter) {
-        (None, None) => None,
-        (None, Some(filter)) => Some(filter.clone()),
-        (Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(), 
None)),
-        (Some(nulls), Some(filter)) => {
-            let combined = nulls.inner() & filter.values();
-            Some(BooleanArray::new(combined, None))
-        }
-    };
-
     for col in value_columns.iter() {
         debug_assert_eq!(col.len(), group_indices.len());
     }
 
+    // Start with rows where all value columns are non-null.
+    let mut valid_indices =
+        NullBuffer::union_many(value_columns.iter().map(|arr| arr.nulls()))
+            .map(NullBuffer::into_inner);
+
+    // Restrict to rows where the optional filter is Some(true). Keep the 
filter
+    // as a raw BooleanBuffer to avoid computing a NullBuffer null_count just 
to
+    // test row validity below.
+    if let Some(filter) = opt_filter {
+        debug_assert_eq!(filter.len(), group_indices.len());
+        let filter_validity = filter_to_validity(filter);
+        if let Some(valid_indices) = valid_indices.as_mut() {
+            *valid_indices &= &filter_validity;
+        } else {
+            valid_indices = Some(filter_validity);
+        }
+    }
+
     match valid_indices {
         None => {
             for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
@@ -562,7 +556,8 @@ pub fn accumulate_indices<F>(
         (None, Some(filter)) => {
             debug_assert_eq!(filter.len(), group_indices.len());
             let group_indices_chunks = group_indices.chunks_exact(64);
-            let bit_chunks = filter.values().bit_chunks();
+            let filter_validity = filter_to_validity(filter);
+            let bit_chunks = filter_validity.bit_chunks();
 
             let group_indices_remainder = group_indices_chunks.remainder();
 
@@ -636,7 +631,8 @@ pub fn accumulate_indices<F>(
 
             let group_indices_chunks = group_indices.chunks_exact(64);
             let valid_bit_chunks = valids.inner().bit_chunks();
-            let filter_bit_chunks = filter.values().bit_chunks();
+            let filter_validity = filter_to_validity(filter);
+            let filter_bit_chunks = filter_validity.bit_chunks();
 
             let group_indices_remainder = group_indices_chunks.remainder();
 
@@ -1188,6 +1184,68 @@ mod test {
         assert_eq!(accumulated, expected);
     }
 
+    #[test]
+    fn test_accumulate_indices_with_null_filter() {
+        let group_indices = vec![0, 1, 0, 1];
+        let filter = BooleanArray::new(
+            BooleanBuffer::from(vec![true, true, true, false]),
+            Some(NullBuffer::from(vec![true, false, true, true])),
+        );
+
+        let mut accumulated = vec![];
+        accumulate_indices(&group_indices, None, Some(&filter), |group_idx| {
+            accumulated.push(group_idx);
+        });
+
+        // A NULL filter value should be treated the same as false, even if the
+        // underlying BooleanBuffer value is true.
+        let expected = vec![0, 0];
+        assert_eq!(accumulated, expected);
+
+        let value_validity = NullBuffer::from(vec![true, true, false, true]);
+        let mut accumulated = vec![];
+        accumulate_indices(
+            &group_indices,
+            Some(&value_validity),
+            Some(&filter),
+            |group_idx| {
+                accumulated.push(group_idx);
+            },
+        );
+
+        let expected = vec![0];
+        assert_eq!(accumulated, expected);
+    }
+
+    #[test]
+    fn test_accumulate_multiple_with_null_filter() {
+        let group_indices = vec![0, 1, 0, 1];
+        let values1 = Int32Array::from(vec![1, 2, 3, 4]);
+        let values2 = Int32Array::from(vec![10, 20, 30, 40]);
+        let value_columns = [values1, values2];
+
+        let filter = BooleanArray::new(
+            BooleanBuffer::from(vec![true, true, true, false]),
+            Some(NullBuffer::from(vec![true, false, true, true])),
+        );
+
+        let mut accumulated = vec![];
+        accumulate_multiple(
+            &group_indices,
+            &value_columns.iter().collect::<Vec<_>>(),
+            Some(&filter),
+            |group_idx, batch_idx, columns| {
+                let values = columns.iter().map(|col| 
col.value(batch_idx)).collect();
+                accumulated.push((group_idx, values));
+            },
+        );
+
+        // A NULL filter value should be treated the same as false, even if the
+        // underlying BooleanBuffer value is true.
+        let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])];
+        assert_eq!(accumulated, expected);
+    }
+
     #[test]
     fn test_accumulate_multiple_with_nulls_and_filter() {
         let group_indices = vec![0, 1, 0, 1];
diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs
index 5b56b77e11..d524afe43a 100644
--- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs
+++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs
@@ -22,7 +22,7 @@ use arrow::array::{
     BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, 
StringArray,
     StringViewArray, StructArray,
 };
-use arrow::buffer::NullBuffer;
+use arrow::buffer::{BooleanBuffer, NullBuffer};
 use arrow::datatypes::DataType;
 use datafusion_common::{Result, not_impl_err};
 use std::sync::Arc;
@@ -39,15 +39,24 @@ pub fn set_nulls<T: ArrowNumericType + Send>(
     PrimitiveArray::<T>::new(values, nulls).with_data_type(dt)
 }
 
-/// Converts a `BooleanBuffer` representing a filter to a `NullBuffer.
+/// Converts an aggregate filter expression to a validity bitmap.
+///
+/// The output is `true` for rows where the filter is `Some(true)`, and `false`
+/// for rows where the filter is `Some(false)` or `None`.
+pub(crate) fn filter_to_validity(filter: &BooleanArray) -> BooleanBuffer {
+    let Some(filter_nulls) = filter.nulls() else {
+        return filter.values().clone();
+    };
+    filter.values() & filter_nulls.inner()
+}
+
+/// Converts an aggregate filter expression to a `NullBuffer`.
 ///
 /// The `NullBuffer` is
-/// * `true` (representing valid) for values that were `true` in filter
-/// * `false` (representing null) for values that were `false` or `null` in 
filter
-pub fn filter_to_nulls(filter: &BooleanArray) -> Option<NullBuffer> {
-    let (filter_bools, filter_nulls) = filter.clone().into_parts();
-    let filter_bools = NullBuffer::from(filter_bools);
-    NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref())
+/// * `true` (representing valid) for filter values that were `Some(true)`
+/// * `false` (representing null) for filter values that were `Some(false)` or 
`None`
+pub fn filter_to_nulls(filter: &BooleanArray) -> NullBuffer {
+    NullBuffer::new(filter_to_validity(filter))
 }
 
 /// Compute an output validity mask for an array that has been filtered
@@ -97,7 +106,7 @@ pub fn filtered_null_mask(
     opt_filter: Option<&BooleanArray>,
     input: &dyn Array,
 ) -> Option<NullBuffer> {
-    let opt_filter = opt_filter.and_then(filter_to_nulls);
+    let opt_filter = opt_filter.map(filter_to_nulls);
     NullBuffer::union(opt_filter.as_ref(), input.nulls())
 }
 
diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs
index 4365c436a5..24edaaff1f 100644
--- a/datafusion/functions-aggregate/src/array_agg.rs
+++ b/datafusion/functions-aggregate/src/array_agg.rs
@@ -776,7 +776,7 @@ impl GroupsAccumulator for ArrayAggGroupsAccumulator {
         let offsets = OffsetBuffer::from_repeated_length(1, input.len());
 
         // Filtered rows become null list entries, which merge_batch will skip.
-        let filter_nulls = opt_filter.and_then(filter_to_nulls);
+        let filter_nulls = opt_filter.map(filter_to_nulls);
 
         // With ignore_nulls, null values also become null list entries. 
Without
         // ignore_nulls, null values stay as [NULL] so merge_batch retains 
them.
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 70acff3cb7..b8009dfd57 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -693,6 +693,18 @@ from data
 ----
 1
 
+# correlation_with_group_by_and_nullable_filter
+query IR rowsort
+SELECT g, corr(x, y) FILTER (WHERE b < 1) AS r
+FROM (VALUES
+  (0, 1.0, 1.0, CAST(NULL AS INT)),
+  (0, 2.0, 2.0, CAST(NULL AS INT)),
+  (0, 3.0, 4.0, 2)
+) AS t(g, x, y, b)
+GROUP BY g
+----
+0 NULL
+
 # group correlation_query_with_nans_f32
 query IR
 select id, corr(f, b)
@@ -6177,6 +6189,14 @@ FROM test_table
 ----
 2
 
+# count_with_group_by_and_nullable_filter
+query II rowsort
+SELECT g, COUNT(a) FILTER (WHERE b < 1) AS count_a
+FROM (VALUES (0, 1, CAST(NULL AS INT)), (0, 2, 2)) AS t(g, a, b)
+GROUP BY g
+----
+0 0
+
 # query_with_and_without_filter
 query III rowsort
 SELECT


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to