This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch hash_agg_spike
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit c3496ccd6d719c420d7e4dd546e43ca8b1232f5a
Author: Andrew Lamb <[email protected]>
AuthorDate: Sat Jul 1 06:45:04 2023 -0400

    update more comments
---
 datafusion/physical-expr/src/aggregate/average.rs  | 26 ++++----
 .../src/aggregate/groups_accumulator/accumulate.rs | 70 ++++++++++++++++------
 2 files changed, 66 insertions(+), 30 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index 3f3c7820be..ee249f3bd1 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -467,8 +467,8 @@ where
     /// Adds one to each group's counter
     fn increment_counts(
         &mut self,
-        values: &PrimitiveArray<T>,
         group_indicies: &[usize],
+        values: &PrimitiveArray<T>,
         opt_filter: Option<&arrow_array::BooleanArray>,
         total_num_groups: usize,
     ) {
@@ -476,8 +476,8 @@ where
 
         if values.null_count() == 0 {
             accumulate_all(
-                values,
                 group_indicies,
+                values,
                 opt_filter,
                 |group_index, _new_value| {
                     self.counts[group_index] += 1;
@@ -485,8 +485,8 @@ where
             )
         } else {
             accumulate_all_nullable(
-                values,
                 group_indicies,
+                values,
                 opt_filter,
                 |group_index, _new_value, is_valid| {
                     if is_valid {
@@ -500,8 +500,8 @@ where
     /// Adds the counts with the partial counts
     fn update_counts_with_partial_counts(
         &mut self,
-        partial_counts: &UInt64Array,
         group_indicies: &[usize],
+        partial_counts: &UInt64Array,
         opt_filter: Option<&arrow_array::BooleanArray>,
         total_num_groups: usize,
     ) {
@@ -509,8 +509,8 @@ where
 
         if partial_counts.null_count() == 0 {
             accumulate_all(
-                partial_counts,
                 group_indicies,
+                partial_counts,
                 opt_filter,
                 |group_index, partial_count| {
                     self.counts[group_index] += partial_count;
@@ -518,8 +518,8 @@ where
             )
         } else {
             accumulate_all_nullable(
-                partial_counts,
                 group_indicies,
+                partial_counts,
                 opt_filter,
                 |group_index, partial_count, is_valid| {
                     if is_valid {
@@ -533,8 +533,8 @@ where
     /// Adds the values in `values` to self.sums
     fn update_sums(
         &mut self,
-        values: &PrimitiveArray<T>,
         group_indicies: &[usize],
+        values: &PrimitiveArray<T>,
         opt_filter: Option<&arrow_array::BooleanArray>,
         total_num_groups: usize,
     ) {
@@ -543,8 +543,8 @@ where
 
         if values.null_count() == 0 {
             accumulate_all(
-                values,
                 group_indicies,
+                values,
                 opt_filter,
                 |group_index, new_value| {
                     let sum = &mut self.sums[group_index];
@@ -553,8 +553,8 @@ where
             )
         } else {
             accumulate_all_nullable(
-                values,
                 group_indicies,
+                values,
                 opt_filter,
                 |group_index, new_value, is_valid| {
                     if is_valid {
@@ -582,8 +582,8 @@ where
         assert_eq!(values.len(), 1, "single argument to update_batch");
         let values = values.get(0).unwrap().as_primitive::<T>();
 
-        self.increment_counts(values, group_indicies, opt_filter, total_num_groups);
-        self.update_sums(values, group_indicies, opt_filter, total_num_groups);
+        self.increment_counts(group_indicies, values, opt_filter, total_num_groups);
+        self.update_sums(group_indicies, values, opt_filter, total_num_groups);
 
         Ok(())
     }
@@ -600,12 +600,12 @@ where
         let partial_counts = values.get(0).unwrap().as_primitive::<UInt64Type>();
         let partial_sums = values.get(1).unwrap().as_primitive::<T>();
         self.update_counts_with_partial_counts(
-            partial_counts,
             group_indicies,
+            partial_counts,
             opt_filter,
             total_num_groups,
         );
-        self.update_sums(partial_sums, group_indicies, opt_filter, total_num_groups);
+        self.update_sums(group_indicies, partial_sums, opt_filter, total_num_groups);
 
         Ok(())
     }
diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs
index 5d72328763..f8a6791def 100644
--- a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs
+++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs
@@ -19,23 +19,55 @@
 
 use arrow_array::{Array, ArrowNumericType, PrimitiveArray};
 
-/// This function is called to update the accumulator state per row,
+/// This function is used to update the accumulator state per row,
 /// for a `PrimitiveArray<T>` with no nulls. It is the inner loop for
 /// many GroupsAccumulators and thus performance critical.
 ///
-/// I couldn't find any way to combine this with
-/// accumulate_all_nullable without having to pass in a is_null on
-/// every row.
+/// # Arguments:
 ///
 /// * `values`: the input arguments to the accumulator
 /// * `group_indices`:  To which groups do the rows in `values` belong, group id)
-/// * `opt_filter`: if present, only update aggregate state using values[i] if opt_filter[i] is true
+/// * `opt_filter`: if present, invoke `value_fn` only if `opt_filter[i]` is true
+/// * `value_fn`: function invoked for each (group_index, value) pair.
+///
+/// `F`: Invoked for each input row like `value_fn(group_index, value)`
+///
+/// # Example
+///
+/// ```text
+///  ┌─────────┐   ┌─────────┐   ┌ ─ ─ ─ ─ ┐
+///  │ ┌─────┐ │   │ ┌─────┐ │     ┌─────┐
+///  │ │  2  │ │   │ │ 200 │ │   │ │  t  │ │
+///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
+///  │ │  2  │ │   │ │ 100 │ │   │ │  f  │ │
+///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
+///  │ │  0  │ │   │ │ 200 │ │   │ │  t  │ │
+///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
+///  │ │  1  │ │   │ │ 200 │ │   │ │NULL │ │
+///  │ ├─────┤ │   │ ├─────┤ │     ├─────┤
+///  │ │  0  │ │   │ │ 300 │ │   │ │  t  │ │
+///  │ └─────┘ │   │ └─────┘ │     └─────┘
+///  └─────────┘   └─────────┘   └ ─ ─ ─ ─ ┘
+///
+/// group_indices   values        opt_filter
+/// ```
+///
+/// In the example above, `value_fn` is invoked for each (group_index,
+/// value) pair where `opt_filter[i]` is true
+///
+/// ```text
+/// value_fn(2, 200)
+/// value_fn(0, 200)
+/// value_fn(0, 300)
+/// ```
+///
+/// I couldn't find any way to combine this with
+/// accumulate_all_nullable without having to pass in a is_null on
+/// every row.
 ///
-/// `F`: The function to invoke for a non null input row to update the
-/// accumulator state. Called like `value_fn(group_index, value)
 pub fn accumulate_all<T, F>(
-    values: &PrimitiveArray<T>,
     group_indicies: &[usize],
+    values: &PrimitiveArray<T>,
     opt_filter: Option<&arrow_array::BooleanArray>,
     mut value_fn: F,
 ) where
@@ -57,19 +89,16 @@ pub fn accumulate_all<T, F>(
 }
 
 /// This function is called to update the accumulator state per row,
-/// for a `PrimitiveArray<T>` with no nulls. It is the inner loop for
-/// many GroupsAccumulators and thus performance critical.
+/// for a `PrimitiveArray<T>` that can have nulls. See
+/// [`accumulate_all`] for more detail and an example.
 ///
-/// * `values`: the input arguments to the accumulator
-/// * `group_indices`:  To which groups do the rows in `values` belong, group id)
-/// * `opt_filter`: if present, only update aggregate state using values[i] if opt_filter[i] is true
+/// `F`: Invoked like `value_fn(group_index, value, is_valid)`.
 ///
-/// `F`: The function to invoke for an input row to update the
-/// accumulator state. Called like `value_fn(group_index, value,
-/// is_valid). NOTE the parameter is true when the value is VALID.
+/// NOTE the parameter is true when the value is VALID (not when it is
+/// NULL).
 pub fn accumulate_all_nullable<T, F>(
-    values: &PrimitiveArray<T>,
     group_indicies: &[usize],
+    values: &PrimitiveArray<T>,
     opt_filter: Option<&arrow_array::BooleanArray>,
     mut value_fn: F,
 ) where
@@ -119,3 +148,10 @@ pub fn accumulate_all_nullable<T, F>(
             value_fn(group_index, new_value, is_valid)
         });
 }
+
+#[cfg(test)]
+mod test {
+
+    #[test]
+    fn basic() {}
+}

Reply via email to