This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch hash_agg_spike in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
commit c3496ccd6d719c420d7e4dd546e43ca8b1232f5a Author: Andrew Lamb <[email protected]> AuthorDate: Sat Jul 1 06:45:04 2023 -0400 update more comments --- datafusion/physical-expr/src/aggregate/average.rs | 26 ++++---- .../src/aggregate/groups_accumulator/accumulate.rs | 70 ++++++++++++++++------ 2 files changed, 66 insertions(+), 30 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 3f3c7820be..ee249f3bd1 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -467,8 +467,8 @@ where /// Adds one to each group's counter fn increment_counts( &mut self, - values: &PrimitiveArray<T>, group_indicies: &[usize], + values: &PrimitiveArray<T>, opt_filter: Option<&arrow_array::BooleanArray>, total_num_groups: usize, ) { @@ -476,8 +476,8 @@ where if values.null_count() == 0 { accumulate_all( - values, group_indicies, + values, opt_filter, |group_index, _new_value| { self.counts[group_index] += 1; @@ -485,8 +485,8 @@ where ) } else { accumulate_all_nullable( - values, group_indicies, + values, opt_filter, |group_index, _new_value, is_valid| { if is_valid { @@ -500,8 +500,8 @@ where /// Adds the counts with the partial counts fn update_counts_with_partial_counts( &mut self, - partial_counts: &UInt64Array, group_indicies: &[usize], + partial_counts: &UInt64Array, opt_filter: Option<&arrow_array::BooleanArray>, total_num_groups: usize, ) { @@ -509,8 +509,8 @@ where if partial_counts.null_count() == 0 { accumulate_all( - partial_counts, group_indicies, + partial_counts, opt_filter, |group_index, partial_count| { self.counts[group_index] += partial_count; @@ -518,8 +518,8 @@ where ) } else { accumulate_all_nullable( - partial_counts, group_indicies, + partial_counts, opt_filter, |group_index, partial_count, is_valid| { if is_valid { @@ -533,8 +533,8 @@ where /// Adds the values in `values` to self.sums fn update_sums( &mut self, - 
values: &PrimitiveArray<T>, group_indicies: &[usize], + values: &PrimitiveArray<T>, opt_filter: Option<&arrow_array::BooleanArray>, total_num_groups: usize, ) { @@ -543,8 +543,8 @@ where if values.null_count() == 0 { accumulate_all( - values, group_indicies, + values, opt_filter, |group_index, new_value| { let sum = &mut self.sums[group_index]; @@ -553,8 +553,8 @@ where ) } else { accumulate_all_nullable( - values, group_indicies, + values, opt_filter, |group_index, new_value, is_valid| { if is_valid { @@ -582,8 +582,8 @@ where assert_eq!(values.len(), 1, "single argument to update_batch"); let values = values.get(0).unwrap().as_primitive::<T>(); - self.increment_counts(values, group_indicies, opt_filter, total_num_groups); - self.update_sums(values, group_indicies, opt_filter, total_num_groups); + self.increment_counts(group_indicies, values, opt_filter, total_num_groups); + self.update_sums(group_indicies, values, opt_filter, total_num_groups); Ok(()) } @@ -600,12 +600,12 @@ where let partial_counts = values.get(0).unwrap().as_primitive::<UInt64Type>(); let partial_sums = values.get(1).unwrap().as_primitive::<T>(); self.update_counts_with_partial_counts( - partial_counts, group_indicies, + partial_counts, opt_filter, total_num_groups, ); - self.update_sums(partial_sums, group_indicies, opt_filter, total_num_groups); + self.update_sums(group_indicies, partial_sums, opt_filter, total_num_groups); Ok(()) } diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs index 5d72328763..f8a6791def 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs @@ -19,23 +19,55 @@ use arrow_array::{Array, ArrowNumericType, PrimitiveArray}; -/// This function is called to update the accumulator state per row, +/// This function is used to update the accumulator state per row, 
/// for a `PrimitiveArray<T>` with no nulls. It is the inner loop for /// many GroupsAccumulators and thus performance critical. /// -/// I couldn't find any way to combine this with -/// accumulate_all_nullable without having to pass in a is_null on -/// every row. +/// # Arguments: /// /// * `values`: the input arguments to the accumulator /// * `group_indices`: To which groups do the rows in `values` belong, group id) -/// * `opt_filter`: if present, only update aggregate state using values[i] if opt_filter[i] is true +/// * `opt_filter`: if present, invoke value_fn if opt_filter[i] is true +/// * `value_fn`: function invoked for each (group_index, value) pair. +/// +/// `F`: Invoked for each input row like `value_fn(group_index, value) +/// +/// # Example +/// +/// ``` +/// ┌─────────┐ ┌─────────┐ ┌ ─ ─ ─ ─ ┐ +/// │ ┌─────┐ │ │ ┌─────┐ │ ┌─────┐ +/// │ │ 2 │ │ │ │ 200 │ │ │ │ t │ │ +/// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ +/// │ │ 2 │ │ │ │ 100 │ │ │ │ f │ │ +/// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ +/// │ │ 0 │ │ │ │ 200 │ │ │ │ t │ │ +/// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ +/// │ │ 1 │ │ │ │ 200 │ │ │ │NULL │ │ +/// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ +/// │ │ 0 │ │ │ │ 300 │ │ │ │ t │ │ +/// │ └─────┘ │ │ └─────┘ │ └─────┘ +/// └─────────┘ └─────────┘ └ ─ ─ ─ ─ ┘ +/// +/// group_indices values opt_filter +/// ``` +/// +/// In the example above, `value_fn` is invoked for each (group_index, +/// value) pair where `opt_filter[i]` is true +/// +/// ```text +/// value_fn(2, 200) +/// value_fn(0, 200) +/// value_fn(0, 300) +/// ``` +/// +/// I couldn't find any way to combine this with +/// accumulate_all_nullable without having to pass in a is_null on +/// every row. /// -/// `F`: The function to invoke for a non null input row to update the -/// accumulator state. 
Called like `value_fn(group_index, value) pub fn accumulate_all<T, F>( - values: &PrimitiveArray<T>, group_indicies: &[usize], + values: &PrimitiveArray<T>, opt_filter: Option<&arrow_array::BooleanArray>, mut value_fn: F, ) where @@ -57,19 +89,16 @@ pub fn accumulate_all<T, F>( } /// This function is called to update the accumulator state per row, -/// for a `PrimitiveArray<T>` with no nulls. It is the inner loop for -/// many GroupsAccumulators and thus performance critical. +/// for a `PrimitiveArray<T>` that can have nulls. See +/// [`accumulate_all`] for more detail and example /// -/// * `values`: the input arguments to the accumulator -/// * `group_indices`: To which groups do the rows in `values` belong, group id) -/// * `opt_filter`: if present, only update aggregate state using values[i] if opt_filter[i] is true +/// `F`: Invoked like `value_fn(group_index, value, is_valid). /// -/// `F`: The function to invoke for an input row to update the -/// accumulator state. Called like `value_fn(group_index, value, -/// is_valid). NOTE the parameter is true when the value is VALID. +/// NOTE the parameter is true when the value is VALID (not when it is +/// NULL). pub fn accumulate_all_nullable<T, F>( - values: &PrimitiveArray<T>, group_indicies: &[usize], + values: &PrimitiveArray<T>, opt_filter: Option<&arrow_array::BooleanArray>, mut value_fn: F, ) where @@ -119,3 +148,10 @@ pub fn accumulate_all_nullable<T, F>( value_fn(group_index, new_value, is_valid) }); } + +#[cfg(test)] +mod test { + + #[test] + fn basic() {} +}
