This is an automated email from the ASF dual-hosted git repository. dheres pushed a commit to branch hash_agg_spike in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
commit 668063689250cd9a0a7883af8299c4d3bc17f1f1 Author: Andrew Lamb <[email protected]> AuthorDate: Sat Jul 1 05:24:40 2023 -0400 factor out accumulate --- datafusion/physical-expr/src/aggregate/average.rs | 93 ++++++++++++++++++++++- 1 file changed, 90 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 7043ed9ce1..0dcff7ec9b 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -417,8 +417,95 @@ impl RowAccumulator for AvgRowAccumulator { } } +/// This function is called once per row to update the accumulator, +/// for a `PrimitiveArray<T>` and is the inner loop for many +/// GroupsAccumulators and thus performance critical. +/// +/// * `values`: the input arguments to the accumulator +/// * `group_indices`: To which groups do the rows in `values` belong, group id) +/// * `opt_filter`: if present, only update aggregate state using values[i] if opt_filter[i] is true +/// +/// `F`: The function to invoke for a non null input row to update the +/// accumulator state. Called like `value_fn(group_index, value) +/// +/// `FN`: The function to call for each null input row. Called like +/// `null_fn(group_index) +fn accumulate_all<T, F, FN>( + values: &PrimitiveArray<T>, + group_indicies: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + value_fn: F, + null_fn: FN, +) where + T: ArrowNumericType + Send, + F: Fn(usize, T::Native) + Send, + FN: Fn(usize) + Send, +{ + // AAL TODO handle filter values + // TODO combine the null mask from values and opt_filter + let valids = values.nulls(); + + // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum + let data: &[T::Native] = values.values(); + + match valids { + // no nulls + None => { + let iter = group_indicies.iter().zip(data.iter()); + for (&group_index, &new_value) in iter { + value_fn(group_index, new_value) + } + } + // there are nulls, so handle them specially + Some(valids) => { + let group_indices_chunks = group_indicies.chunks_exact(64); + let data_chunks = data.chunks_exact(64); + let bit_chunks = valids.inner().bit_chunks(); + + let group_indices_remainder = group_indices_chunks.remainder(); + let data_remainder = data_chunks.remainder(); + + group_indices_chunks + .zip(data_chunks) + .zip(bit_chunks.iter()) + .for_each(|((group_index_chunk, data_chunk), mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().zip(data_chunk.iter()).for_each( + |(&group_index, &new_value)| { + // valid bit was set, real vale + if (mask & index_mask) != 0 { + value_fn(group_index, new_value); + } else { + null_fn(group_index) + } + index_mask <<= 1; + }, + ) + }); + + // handle any remaining bits (after the intial 64) + let remainder_bits = bit_chunks.remainder_bits(); + group_indices_remainder + .iter() + .zip(data_remainder.iter()) + .enumerate() + .for_each(|(i, (&group_index, &new_value))| { + if remainder_bits & (1 << i) != 0 { + value_fn(group_index, new_value) + } else { + null_fn(group_index) + } + }); + } + } +} + /// An accumulator to compute the average of PrimitiveArray<T>. /// Stores values as native types, and does overflow checking +/// +/// F: Function that calcuates the average value from a sum of +/// T::Native and a total count #[derive(Debug)] struct AvgGroupsAccumulator<T, F> where @@ -597,7 +684,7 @@ where let array = PrimitiveArray::<T>::from_iter_values(averages); // fix up decimal precision and scale for decimals - let array = set_decimal_precision(&self.return_data_type, Arc::new(array))?; + let array = adjust_output_array(&self.return_data_type, Arc::new(array))?; Ok(array) } @@ -614,7 +701,7 @@ where let sums: PrimitiveArray<T> = PrimitiveArray::from_iter_values(sums); // fix up decimal precision and scale for decimals - let sums = set_decimal_precision(&self.sum_data_type, Arc::new(sums))?; + let sums = adjust_output_array(&self.sum_data_type, Arc::new(sums))?; Ok(vec