This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch hash_agg_spike
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit 892e440538b984877b82edcc96b4e2e5f80e70dd
Author: Andrew Lamb <[email protected]>
AuthorDate: Sat Jul 1 05:46:52 2023 -0400

    split nullable/non nullable handling
---
 datafusion/physical-expr/src/aggregate/average.rs | 139 ++++++++++++----------
 1 file changed, 79 insertions(+), 60 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index 0dcff7ec9b..20ccadd7e8 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -417,9 +417,13 @@ impl RowAccumulator for AvgRowAccumulator {
     }
 }
 
-/// This function is called once per row to update the accumulator,
-/// for a `PrimitiveArray<T>` and is the inner loop for many
-/// GroupsAccumulators and thus performance critical.
+/// This function is called to update the accumulator state per row,
+/// for a `PrimitiveArray<T>` with no nulls (it panics otherwise). It is
+/// the inner loop for many GroupsAccumulators and thus performance critical.
+///
+/// This is kept separate from `accumulate_all_nullable` because
+/// combining them would require passing an `is_valid` flag for
+/// every row.
 ///
 /// * `values`: the input arguments to the accumulator
 /// * `group_indices`:  To which groups do the rows in `values` belong, group id)
@@ -427,80 +431,95 @@ impl RowAccumulator for AvgRowAccumulator {
 ///
 /// `F`: The function to invoke for a non null input row to update the
 /// accumulator state. Called like `value_fn(group_index, value)
-///
-/// `FN`: The function to call for each null input row.  Called like
-/// `null_fn(group_index)
 fn accumulate_all<T, F, FN>(
     values: &PrimitiveArray<T>,
     group_indicies: &[usize],
     opt_filter: Option<&arrow_array::BooleanArray>,
     value_fn: F,
-    null_fn: FN,
 ) where
     T: ArrowNumericType + Send,
     F: Fn(usize, T::Native) + Send,
-    FN: Fn(usize) + Send,
 {
+    assert_eq!(
+        values.null_count(), 0,
+        "Called accumulate_all with nullable array (call accumulate_all_nullable instead)"
+    );
+
     // AAL TODO handle filter values
+
+    let data: &[T::Native] = values.values();
+    let iter = group_indicies.iter().zip(data.iter());
+    for (&group_index, &new_value) in iter {
+        value_fn(group_index, new_value)
+    }
+}
+
+
+/// Updates the accumulator state for each row of a `PrimitiveArray<T>`
+/// that has a validity (null) buffer; panics if it has none. It is the
+/// inner loop for many GroupsAccumulators and thus performance critical.
+///
+/// * `values`: the input arguments to the accumulator
+/// * `group_indices`: To which groups do the rows in `values` belong (group id)
+/// * `opt_filter`: if present, only update aggregate state using values[i] if opt_filter[i] is true (TODO: not yet applied)
+///
+/// `F`: The function to invoke for every input row to update the accumulator
+/// state. Called like `value_fn(group_index, value, is_valid)`; `is_valid` is
+/// true when the row is non-null (for null rows `value` is just the buffer slot).
+fn accumulate_all_nullable<T, F, FN>(
+    values: &PrimitiveArray<T>,
+    group_indicies: &[usize],
+    opt_filter: Option<&arrow_array::BooleanArray>,
+    value_fn: F,
+) where
+    T: ArrowNumericType + Send,
+    F: Fn(usize, T::Native, bool) + Send,
+{
+    // AAL TODO handle filter values
     // TODO combine the null mask from values and opt_filter
-    let valids = values.nulls();
+    let valids = values
+        .nulls()
+        .expect("Called accumulate_all_nullable with non-nullable array (call accumulate_all instead)");
 
     // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum
     let data: &[T::Native] = values.values();
 
-    match valids {
-        // no nulls
-        None => {
-            let iter = group_indicies.iter().zip(data.iter());
-            for (&group_index, &new_value) in iter {
-                value_fn(group_index, new_value)
-            }
-        }
-        // there are nulls, so handle them specially
-        Some(valids) => {
-            let group_indices_chunks = group_indicies.chunks_exact(64);
-            let data_chunks = data.chunks_exact(64);
-            let bit_chunks = valids.inner().bit_chunks();
-
-            let group_indices_remainder = group_indices_chunks.remainder();
-            let data_remainder = data_chunks.remainder();
-
-            group_indices_chunks
-                .zip(data_chunks)
-                .zip(bit_chunks.iter())
-                .for_each(|((group_index_chunk, data_chunk), mask)| {
-                    // index_mask has value 1 << i in the loop
-                    let mut index_mask = 1;
-                    group_index_chunk.iter().zip(data_chunk.iter()).for_each(
-                        |(&group_index, &new_value)| {
-                            // valid bit was set, real vale
-                            if (mask & index_mask) != 0 {
-                                value_fn(group_index, new_value);
-                            } else {
-                                null_fn(group_index)
-                            }
-                            index_mask <<= 1;
-                        },
-                    )
-                });
-
-            // handle any remaining bits (after the intial 64)
-            let remainder_bits = bit_chunks.remainder_bits();
-            group_indices_remainder
-                .iter()
-                .zip(data_remainder.iter())
-                .enumerate()
-                .for_each(|(i, (&group_index, &new_value))| {
-                    if remainder_bits & (1 << i) != 0 {
-                        value_fn(group_index, new_value)
-                    } else {
-                        null_fn(group_index)
-                    }
-                });
-        }
-    }
+    let group_indices_chunks = group_indicies.chunks_exact(64);
+    let data_chunks = data.chunks_exact(64);
+    let bit_chunks = valids.inner().bit_chunks();
+
+    let group_indices_remainder = group_indices_chunks.remainder();
+    let data_remainder = data_chunks.remainder();
+
+    group_indices_chunks
+        .zip(data_chunks)
+        .zip(bit_chunks.iter())
+        .for_each(|((group_index_chunk, data_chunk), mask)| {
+            // index_mask has value 1 << i in the loop
+            let mut index_mask = 1;
+            group_index_chunk.iter().zip(data_chunk.iter()).for_each(
+                |(&group_index, &new_value)| {
+                    // the row is valid (non-null) iff its bit in `mask` is set
+                    let is_valid = (mask & index_mask) != 0;
+                    value_fn(group_index, new_value, is_valid);
+                    index_mask <<= 1;
+                },
+            )
+        });
+
+    // handle the trailing rows that do not fill a complete 64-row chunk
+    let remainder_bits = bit_chunks.remainder_bits();
+    group_indices_remainder
+        .iter()
+        .zip(data_remainder.iter())
+        .enumerate()
+        .for_each(|(i, (&group_index, &new_value))| {
+            let is_valid = remainder_bits & (1 << i) != 0;
+            value_fn(group_index, new_value, is_valid)
+        });
 }

+
 /// An accumulator to compute the average of PrimitiveArray<T>.
 /// Stores values as native types, and does overflow checking
 ///

Reply via email to