This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch hash_agg_spike
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit 0cfa1bc3c9c17d34f2a5eaf22b4269d31da100d3
Author: Andrew Lamb <[email protected]>
AuthorDate: Sat Jul 1 06:07:44 2023 -0400

    Refactor out accumulation in average
---
 datafusion/physical-expr/src/aggregate/average.rs | 189 ++++++++++++----------
 1 file changed, 105 insertions(+), 84 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index 20ccadd7e8..2d9a627a5f 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -431,14 +431,14 @@ impl RowAccumulator for AvgRowAccumulator {
 ///
 /// `F`: The function to invoke for a non null input row to update the
 /// accumulator state. Called like `value_fn(group_index, value)
-fn accumulate_all<T, F, FN>(
+fn accumulate_all<T, F>(
     values: &PrimitiveArray<T>,
     group_indicies: &[usize],
     opt_filter: Option<&arrow_array::BooleanArray>,
-    value_fn: F,
+    mut value_fn: F,
 ) where
     T: ArrowNumericType + Send,
-    F: Fn(usize, T::Native) + Send,
+    F: FnMut(usize, T::Native) + Send,
 {
     assert_eq!(
         values.null_count(), 0,
@@ -454,7 +454,6 @@ fn accumulate_all<T, F, FN>(
     }
 }
 
-
 /// This function is called to update the accumulator state per row,
 /// for a `PrimitiveArray<T>` with no nulls. It is the inner loop for
 /// many GroupsAccumulators and thus performance critical.
@@ -466,16 +465,16 @@ fn accumulate_all<T, F, FN>(
 /// `F`: The function to invoke for an input row to update the
 /// accumulator state. Called like `value_fn(group_index, value,
 /// is_valid). NOTE the parameter is true when the value is VALID.
-fn accumulate_all_nullable<T, F, FN>(
+fn accumulate_all_nullable<T, F>(
     values: &PrimitiveArray<T>,
     group_indicies: &[usize],
     opt_filter: Option<&arrow_array::BooleanArray>,
-    value_fn: F,
+    mut value_fn: F,
 ) where
     T: ArrowNumericType + Send,
-    F: Fn(usize, T::Native, bool) + Send,
+    F: FnMut(usize, T::Native, bool) + Send,
 {
-     // AAL TODO handle filter values
+    // AAL TODO handle filter values
     // TODO combine the null mask from values and opt_filter
     let valids = values
         .nulls()
@@ -519,7 +518,6 @@ fn accumulate_all_nullable<T, F, FN>(
         });
 }
 
-
 /// An accumulator to compute the average of PrimitiveArray<T>.
 /// Stores values as native types, and does overflow checking
 ///
@@ -566,6 +564,72 @@ where
         }
     }
 
+    /// Adds one to each group's counter
+    fn increment_counts(
+        &mut self,
+        values: &PrimitiveArray<T>,
+        group_indicies: &[usize],
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) {
+        self.counts.resize(total_num_groups, 0);
+
+        if values.null_count() == 0 {
+            accumulate_all(
+                values,
+                group_indicies,
+                opt_filter,
+                |group_index, _new_value| {
+                    self.counts[group_index] += 1;
+                },
+            )
+        } else {
+            accumulate_all_nullable(
+                values,
+                group_indicies,
+                opt_filter,
+                |group_index, _new_value, is_valid| {
+                    if is_valid {
+                        self.counts[group_index] += 1;
+                    }
+                },
+            )
+        }
+    }
+
+    /// Adds the counts with the partial counts
+    fn update_counts_with_partial_counts(
+        &mut self,
+        partial_counts: &UInt64Array,
+        group_indicies: &[usize],
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) {
+        self.counts.resize(total_num_groups, 0);
+
+        if partial_counts.null_count() == 0 {
+            accumulate_all(
+                partial_counts,
+                group_indicies,
+                opt_filter,
+                |group_index, partial_count| {
+                    self.counts[group_index] += partial_count;
+                },
+            )
+        } else {
+            accumulate_all_nullable(
+                partial_counts,
+                group_indicies,
+                opt_filter,
+                |group_index, partial_count, is_valid| {
+                    if is_valid {
+                        self.counts[group_index] += partial_count;
+                    }
+                },
+            )
+        }
+    }
+
     /// Adds the values in `values` to self.sums
     fn update_sums(
         &mut self,
@@ -573,66 +637,33 @@ where
         group_indicies: &[usize],
         opt_filter: Option<&arrow_array::BooleanArray>,
         total_num_groups: usize,
-    ) -> Result<()> {
+    ) {
         self.sums
             .resize_with(total_num_groups, || T::default_value());
 
-        // AAL TODO
-        // TODO combine the null mask from values and opt_filter
-        let valids = values.nulls();
-
-        // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum
-        let data: &[T::Native] = values.values();
-
-        match valids {
-            // use all values in group_index
-            None => {
-                let iter = group_indicies.iter().zip(data.iter());
-                for (group_index, new_value) in iter {
-                    let sum = &mut self.sums[*group_index];
-                    *sum = sum.add_wrapping(*new_value);
-                }
-            }
-            //
-            Some(valids) => {
-                let group_indices_chunks = group_indicies.chunks_exact(64);
-                let data_chunks = data.chunks_exact(64);
-                let bit_chunks = valids.inner().bit_chunks();
-
-                let group_indices_remainder = group_indices_chunks.remainder();
-                let data_remainder = data_chunks.remainder();
-
-                group_indices_chunks
-                    .zip(data_chunks)
-                    .zip(bit_chunks.iter())
-                    .for_each(|((group_index_chunk, data_chunk), mask)| {
-                        // index_mask has value 1 << i in the loop
-                        let mut index_mask = 1;
-                        group_index_chunk.iter().zip(data_chunk.iter()).for_each(
-                            |(group_index, new_value)| {
-                                if (mask & index_mask) != 0 {
-                                    let sum = &mut self.sums[*group_index];
-                                    *sum = sum.add_wrapping(*new_value);
-                                }
-                                index_mask <<= 1;
-                            },
-                        )
-                    });
-
-                let remainder_bits = bit_chunks.remainder_bits();
-                group_indices_remainder
-                    .iter()
-                    .zip(data_remainder.iter())
-                    .enumerate()
-                    .for_each(|(i, (group_index, new_value))| {
-                        if remainder_bits & (1 << i) != 0 {
-                            let sum = &mut self.sums[*group_index];
-                            *sum = sum.add_wrapping(*new_value);
-                        }
-                    });
-            }
+        if values.null_count() == 0 {
+            accumulate_all(
+                values,
+                group_indicies,
+                opt_filter,
+                |group_index, new_value| {
+                    let sum = &mut self.sums[group_index];
+                    *sum = sum.add_wrapping(new_value);
+                },
+            )
+        } else {
+            accumulate_all_nullable(
+                values,
+                group_indicies,
+                opt_filter,
+                |group_index, new_value, is_valid| {
+                    if is_valid {
+                        let sum = &mut self.sums[group_index];
+                        *sum = sum.add_wrapping(new_value);
+                    }
+                },
+            )
         }
-        Ok(())
     }
 }
 
@@ -651,14 +682,8 @@ where
         assert_eq!(values.len(), 1, "single argument to update_batch");
         let values = values.get(0).unwrap().as_primitive::<T>();
 
-        // update counts (TOD account for opt_filter)
-        self.counts.resize(total_num_groups, 0);
-        group_indicies.iter().for_each(|&group_idx| {
-            self.counts[group_idx] += 1;
-        });
-
-        // update values
-        self.update_sums(values, group_indicies, opt_filter, total_num_groups)?;
+        self.increment_counts(values, group_indicies, opt_filter, total_num_groups);
+        self.update_sums(values, group_indicies, opt_filter, total_num_groups);
 
         Ok(())
     }
@@ -672,19 +697,15 @@ where
     ) -> Result<()> {
         assert_eq!(values.len(), 2, "two arguments to merge_batch");
         // first batch is counts, second is partial sums
-        let counts = values.get(0).unwrap().as_primitive::<UInt64Type>();
+        let partial_counts = values.get(0).unwrap().as_primitive::<UInt64Type>();
         let partial_sums = values.get(1).unwrap().as_primitive::<T>();
-
-        // update counts by summing the partial sums (TODO account for opt_filter)
-        self.counts.resize(total_num_groups, 0);
-        group_indicies.iter().zip(counts.values().iter()).for_each(
-            |(&group_idx, &count)| {
-                self.counts[group_idx] += count;
-            },
+        self.update_counts_with_partial_counts(
+            partial_counts,
+            group_indicies,
+            opt_filter,
+            total_num_groups,
         );
-
-        // update values
-        self.update_sums(partial_sums, group_indicies, opt_filter, total_num_groups)?;
+        self.update_sums(partial_sums, group_indicies, opt_filter, total_num_groups);
 
         Ok(())
     }

Reply via email to