This is an automated email from the ASF dual-hosted git repository. dheres pushed a commit to branch hash_agg_spike in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
commit 0cfa1bc3c9c17d34f2a5eaf22b4269d31da100d3 Author: Andrew Lamb <[email protected]> AuthorDate: Sat Jul 1 06:07:44 2023 -0400 Refactor out accumulation in average --- datafusion/physical-expr/src/aggregate/average.rs | 189 ++++++++++++---------- 1 file changed, 105 insertions(+), 84 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 20ccadd7e8..2d9a627a5f 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -431,14 +431,14 @@ impl RowAccumulator for AvgRowAccumulator { /// /// `F`: The function to invoke for a non null input row to update the /// accumulator state. Called like `value_fn(group_index, value) -fn accumulate_all<T, F, FN>( +fn accumulate_all<T, F>( values: &PrimitiveArray<T>, group_indicies: &[usize], opt_filter: Option<&arrow_array::BooleanArray>, - value_fn: F, + mut value_fn: F, ) where T: ArrowNumericType + Send, - F: Fn(usize, T::Native) + Send, + F: FnMut(usize, T::Native) + Send, { assert_eq!( values.null_count(), 0, @@ -454,7 +454,6 @@ fn accumulate_all<T, F, FN>( } } - /// This function is called to update the accumulator state per row, /// for a `PrimitiveArray<T>` with no nulls. It is the inner loop for /// many GroupsAccumulators and thus performance critical. @@ -466,16 +465,16 @@ fn accumulate_all<T, F, FN>( /// `F`: The function to invoke for an input row to update the /// accumulator state. Called like `value_fn(group_index, value, /// is_valid). NOTE the parameter is true when the value is VALID. -fn accumulate_all_nullable<T, F, FN>( +fn accumulate_all_nullable<T, F>( values: &PrimitiveArray<T>, group_indicies: &[usize], opt_filter: Option<&arrow_array::BooleanArray>, - value_fn: F, + mut value_fn: F, ) where T: ArrowNumericType + Send, - F: Fn(usize, T::Native, bool) + Send, + F: FnMut(usize, T::Native, bool) + Send, { - // AAL TODO handle filter values + // AAL TODO handle filter values // TODO combine the null mask from values and opt_filter let valids = values .nulls() @@ -519,7 +518,6 @@ fn accumulate_all_nullable<T, F, FN>( }); } - /// An accumulator to compute the average of PrimitiveArray<T>. /// Stores values as native types, and does overflow checking /// @@ -566,6 +564,72 @@ where } } + /// Adds one to each group's counter + fn increment_counts( + &mut self, + values: &PrimitiveArray<T>, + group_indicies: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) { + self.counts.resize(total_num_groups, 0); + + if values.null_count() == 0 { + accumulate_all( + values, + group_indicies, + opt_filter, + |group_index, _new_value| { + self.counts[group_index] += 1; + }, + ) + } else { + accumulate_all_nullable( + values, + group_indicies, + opt_filter, + |group_index, _new_value, is_valid| { + if is_valid { + self.counts[group_index] += 1; + } + }, + ) + } + } + + /// Adds the counts with the partial counts + fn update_counts_with_partial_counts( + &mut self, + partial_counts: &UInt64Array, + group_indicies: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) { + self.counts.resize(total_num_groups, 0); + + if partial_counts.null_count() == 0 { + accumulate_all( + partial_counts, + group_indicies, + opt_filter, + |group_index, partial_count| { + self.counts[group_index] += partial_count; + }, + ) + } else { + accumulate_all_nullable( + partial_counts, + group_indicies, + opt_filter, + |group_index, partial_count, is_valid| { + if is_valid { + self.counts[group_index] += partial_count; + } + }, + ) + } + } + /// Adds the values in `values` to self.sums fn update_sums( &mut self, @@ -573,66 +637,33 @@ where group_indicies: &[usize], opt_filter: Option<&arrow_array::BooleanArray>, total_num_groups: usize, - ) -> Result<()> { + ) { self.sums .resize_with(total_num_groups, || T::default_value()); - // AAL TODO - // TODO combine the null mask from values and opt_filter - let valids = values.nulls(); - - // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum - let data: &[T::Native] = values.values(); - - match valids { - // use all values in group_index - None => { - let iter = group_indicies.iter().zip(data.iter()); - for (group_index, new_value) in iter { - let sum = &mut self.sums[*group_index]; - *sum = sum.add_wrapping(*new_value); - } - } - // - Some(valids) => { - let group_indices_chunks = group_indicies.chunks_exact(64); - let data_chunks = data.chunks_exact(64); - let bit_chunks = valids.inner().bit_chunks(); - - let group_indices_remainder = group_indices_chunks.remainder(); - let data_remainder = data_chunks.remainder(); - - group_indices_chunks - .zip(data_chunks) - .zip(bit_chunks.iter()) - .for_each(|((group_index_chunk, data_chunk), mask)| { - // index_mask has value 1 << i in the loop - let mut index_mask = 1; - group_index_chunk.iter().zip(data_chunk.iter()).for_each( - |(group_index, new_value)| { - if (mask & index_mask) != 0 { - let sum = &mut self.sums[*group_index]; - *sum = sum.add_wrapping(*new_value); - } - index_mask <<= 1; - }, - ) - }); - - let remainder_bits = bit_chunks.remainder_bits(); - group_indices_remainder - .iter() - .zip(data_remainder.iter()) - .enumerate() - .for_each(|(i, (group_index, new_value))| { - if remainder_bits & (1 << i) != 0 { - let sum = &mut self.sums[*group_index]; - *sum = sum.add_wrapping(*new_value); - } - }); - } + if values.null_count() == 0 { + accumulate_all( + values, + group_indicies, + opt_filter, + |group_index, new_value| { + let sum = &mut self.sums[group_index]; + *sum = sum.add_wrapping(new_value); + }, + ) + } else { + accumulate_all_nullable( + values, + group_indicies, + opt_filter, + |group_index, new_value, is_valid| { + if is_valid { + let sum = &mut self.sums[group_index]; + *sum = sum.add_wrapping(new_value); + } + }, + ) } - Ok(()) } } @@ -651,14 +682,8 @@ where assert_eq!(values.len(), 1, "single argument to update_batch"); let values = values.get(0).unwrap().as_primitive::<T>(); - // update counts (TOD account for opt_filter) - self.counts.resize(total_num_groups, 0); - group_indicies.iter().for_each(|&group_idx| { - self.counts[group_idx] += 1; - }); - - // update values - self.update_sums(values, group_indicies, opt_filter, total_num_groups)?; + self.increment_counts(values, group_indicies, opt_filter, total_num_groups); + self.update_sums(values, group_indicies, opt_filter, total_num_groups); Ok(()) } @@ -672,19 +697,15 @@ where ) -> Result<()> { assert_eq!(values.len(), 2, "two arguments to merge_batch"); // first batch is counts, second is partial sums - let counts = values.get(0).unwrap().as_primitive::<UInt64Type>(); + let partial_counts = values.get(0).unwrap().as_primitive::<UInt64Type>(); let partial_sums = values.get(1).unwrap().as_primitive::<T>(); - - // update counts by summing the partial sums (TODO account for opt_filter) - self.counts.resize(total_num_groups, 0); - group_indicies.iter().zip(counts.values().iter()).for_each( - |(&group_idx, &count)| { - self.counts[group_idx] += count; - }, + self.update_counts_with_partial_counts( + partial_counts, + group_indicies, + opt_filter, + total_num_groups, ); - - // update values - self.update_sums(partial_sums, group_indicies, opt_filter, total_num_groups)?; + self.update_sums(partial_sums, group_indicies, opt_filter, total_num_groups); Ok(()) }
