mustafasrepo commented on code in PR #6904: URL: https://github.com/apache/arrow-datafusion/pull/6904#discussion_r1259930680
########## datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs: ########## @@ -0,0 +1,1042 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`GroupsAccumulator`] helpers: [`NullState`] and [`accumulate_indices`] +//! +//! [`GroupsAccumulator`]: crate::GroupsAccumulator + +use arrow::datatypes::ArrowPrimitiveType; +use arrow_array::{Array, BooleanArray, PrimitiveArray}; +use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; + +/// Track the accumulator null state per row: if any values for that +/// group were null and if any values have been seen at all for that group. +/// +/// This is part of the inner loop for many [`GroupsAccumulator`]s, +/// and thus the performance is critical and so there are multiple +/// specialized implementations, invoked depending on the specific +/// combinations of the input. +/// +/// Typically there are 4 potential combinations of inputs that must be +/// special cased for performance: +/// +/// * With / Without filter +/// * With / Without nulls in the input +/// +/// If the input has nulls, then the accumulator must potentially +/// handle each input null value specially (e.g. 
for `SUM` to mark the +/// corresponding sum as null) +/// +/// If there are filters present, `NullState` tracks if it has seen +/// *any* value for that group (as some values may be filtered +/// out). Without a filter, the accumulator is only passed groups that +/// had at least one value to accumulate so they do not need to track +/// if they have seen values for a particular group. +/// +/// [`GroupsAccumulator`]: crate::GroupsAccumulator +#[derive(Debug)] +pub struct NullState { + /// Tracks if a null input value has been seen for `group_index`, + /// if there were any nulls in the input. + /// + /// If `null_inputs[i]` is true, have not seen any null values for + /// group `i`, or have not seen any values + /// + /// If `null_inputs[i]` is false, saw at least one null value for + /// group `i` + null_inputs: Option<BooleanBufferBuilder>, + + /// If there has been an `opt_filter`, has it seen any + /// non-filtered input values for `group_index`? + /// + /// If `seen_values[i]` is true, have seen at least one non null + /// value for group `i` + /// + /// If `seen_values[i]` is false, have not seen any values that + /// pass the filter yet for group `i` + seen_values: Option<BooleanBufferBuilder>, +} + +impl NullState { + pub fn new() -> Self { + Self { + null_inputs: None, + seen_values: None, + } + } + + /// return the size of all buffers allocated by this null state, not including self + pub fn size(&self) -> usize { + builder_size(self.null_inputs.as_ref()) + builder_size(self.seen_values.as_ref()) + } + + /// Invokes `value_fn(group_index, value)` for each non null, non + /// filtered value of `value`, while tracking which groups have + /// seen null inputs and which groups have seen any inputs if necessary + // + /// # Arguments: + /// + /// * `values`: the input arguments to the accumulator + /// * `group_indices`: To which groups do the rows in `values` belong, (aka group_index) + /// * `opt_filter`: if present, only rows for which is Some(true) are 
included + /// * `value_fn`: function invoked for (group_index, value) where value is non null + /// + /// # Example + /// + /// ```text + /// ┌─────────┐ ┌─────────┐ ┌ ─ ─ ─ ─ ┐ + /// │ ┌─────┐ │ │ ┌─────┐ │ ┌─────┐ + /// │ │ 2 │ │ │ │ 200 │ │ │ │ t │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ + /// │ │ 2 │ │ │ │ 100 │ │ │ │ f │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ + /// │ │ 0 │ │ │ │ 200 │ │ │ │ t │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ + /// │ │ 1 │ │ │ │ 200 │ │ │ │NULL │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ + /// │ │ 0 │ │ │ │ 300 │ │ │ │ t │ │ + /// │ └─────┘ │ │ └─────┘ │ └─────┘ + /// └─────────┘ └─────────┘ └ ─ ─ ─ ─ ┘ + /// + /// group_indices values opt_filter + /// ``` + /// + /// In the example above, `value_fn` is invoked for each (group_index, + /// value) pair where `opt_filter[i]` is true and values is non null + /// + /// ```text + /// value_fn(2, 200) + /// value_fn(0, 200) + /// value_fn(0, 300) + /// ``` + /// + /// It also sets + /// + /// 1. `self.seen_values[group_index]` to true for all rows that had a value if there is a filter + /// + /// 2. 
`self.null_inputs[group_index]` to false for all rows that had a null in input + pub fn accumulate<T, F>( + &mut self, + group_indices: &[usize], + values: &PrimitiveArray<T>, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + mut value_fn: F, + ) where + T: ArrowPrimitiveType + Send, + F: FnMut(usize, T::Native) + Send, + { + let data: &[T::Native] = values.values(); + assert_eq!(data.len(), group_indices.len()); + + match (values.null_count() > 0, opt_filter) { + // no nulls, no filter, + (false, None) => { + // if we have previously seen nulls, ensure the null + // buffer is big enough (start everything at valid) + if self.null_inputs.is_some() { + initialize_builder(&mut self.null_inputs, total_num_groups, true); + } + let iter = group_indices.iter().zip(data.iter()); + for (&group_index, &new_value) in iter { + value_fn(group_index, new_value) + } + } + // nulls, no filter + (true, None) => { + let nulls = values.nulls().unwrap(); + // All groups start as valid (true), and are set to + // null if we see a null in the input) + let null_inputs = + initialize_builder(&mut self.null_inputs, total_num_groups, true); + + // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum + // iterate over in chunks of 64 bits for more efficient null checking + let group_indices_chunks = group_indices.chunks_exact(64); + let data_chunks = data.chunks_exact(64); + let bit_chunks = nulls.inner().bit_chunks(); + + let group_indices_remainder = group_indices_chunks.remainder(); + let data_remainder = data_chunks.remainder(); + + group_indices_chunks + .zip(data_chunks) + .zip(bit_chunks.iter()) + .for_each(|((group_index_chunk, data_chunk), mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().zip(data_chunk.iter()).for_each( + |(&group_index, &new_value)| { + // valid bit was set, real value + let is_valid = (mask & index_mask) != 0; + if is_valid { + value_fn(group_index, new_value); + } else { + // 
input null means this group is now null + null_inputs.set_bit(group_index, false); + } + index_mask <<= 1; + }, + ) + }); + + // handle any remaining bits (after the intial 64) Review Comment: ```suggestion // handle any remaining bits (after the initial 64) ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
