UBarney commented on code in PR #15266: URL: https://github.com/apache/datafusion/pull/15266#discussion_r2000992780
########## datafusion/functions-aggregate/src/first_last.rs: ########## @@ -179,6 +292,423 @@ impl AggregateUDFImpl for FirstValue { } } +struct FirstGroupsAccumulator<T> +where + T: ArrowPrimitiveType + Send, +{ + // ================ state =========== + vals: Vec<T::Native>, + // Stores ordering values, of the aggregator requirement corresponding to first value + // of the aggregator. + // The `orderings` are stored row-wise, meaning that `orderings[group_idx]` + // represents the ordering values corresponding to the `group_idx`-th group. + orderings: Vec<Vec<ScalarValue>>, + // At the beginning, `is_sets[group_idx]` is false, which means `first` is not seen yet. + // Once we see the first value, we set the `is_sets[group_idx]` flag + is_sets: BooleanBufferBuilder, + // null_builder[group_idx] == false => vals[group_idx] is null + null_builder: BooleanBufferBuilder, + // size of `self.orderings` + // Calculating the memory usage of `self.orderings` using `ScalarValue::size_of_vec` is quite costly. + // Therefore, we cache it and compute `size_of` only after each update + // to avoid calling `ScalarValue::size_of_vec` by Self.size. + size_of_orderings: usize, + + // =========== option ============ + + // Stores the applicable ordering requirement. + ordering_req: LexOrdering, + // derived from `ordering_req`. + sort_options: Vec<SortOptions>, + // Stores whether incoming data already satisfies the ordering requirement. + input_requirement_satisfied: bool, + // Ignore null values. 
+ ignore_nulls: bool, + /// The output type + data_type: DataType, + default_orderings: Vec<ScalarValue>, +} + +impl<T> FirstGroupsAccumulator<T> +where + T: ArrowPrimitiveType + Send, +{ + fn try_new( + ordering_req: LexOrdering, + ignore_nulls: bool, + data_type: &DataType, + ordering_dtypes: &[DataType], + ) -> Result<Self> { + let requirement_satisfied = ordering_req.is_empty(); + + let default_orderings = ordering_dtypes + .iter() + .map(ScalarValue::try_from) + .collect::<Result<Vec<_>>>()?; + + let sort_options = get_sort_options(ordering_req.as_ref()); + + Ok(Self { + null_builder: BooleanBufferBuilder::new(0), + ordering_req, + sort_options, + input_requirement_satisfied: requirement_satisfied, + ignore_nulls, + default_orderings, + data_type: data_type.clone(), + vals: Vec::new(), + orderings: Vec::new(), + is_sets: BooleanBufferBuilder::new(0), + size_of_orderings: 0, + }) + } + + fn need_update(&self, group_idx: usize) -> bool { + if !self.is_sets.get_bit(group_idx) { + return true; + } + + if self.ignore_nulls && !self.null_builder.get_bit(group_idx) { + return true; + } + + !self.input_requirement_satisfied + } + + fn should_update_state( + &self, + group_idx: usize, + new_ordering_values: &[ScalarValue], + ) -> Result<bool> { + if !self.is_sets.get_bit(group_idx) { + return Ok(true); + } + + assert!(new_ordering_values.len() == self.ordering_req.len()); + let current_ordering = &self.orderings[group_idx]; + compare_rows(current_ordering, new_ordering_values, &self.sort_options) + .map(|x| x.is_gt()) + } + + fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> { + let result = emit_to.take_needed(&mut self.orderings); + + match emit_to { + EmitTo::All => self.size_of_orderings = 0, + EmitTo::First(_) => { + self.size_of_orderings -= + result.iter().map(ScalarValue::size_of_vec).sum::<usize>() + } + } + + result + } + + fn take_need( + bool_buf_builder: &mut BooleanBufferBuilder, + emit_to: EmitTo, + ) -> BooleanBuffer { + let 
bool_buf = bool_buf_builder.finish(); + match emit_to { + EmitTo::All => bool_buf, + EmitTo::First(n) => { + // split off the first N values in seen_values + // + // TODO make this more efficient rather than two + // copies and bitwise manipulation + let first_n: BooleanBuffer = bool_buf.iter().take(n).collect(); + // reset the existing buffer + for b in bool_buf.iter().skip(n) { + bool_buf_builder.append(b); + } + first_n + } + } + } + + fn resize_states(&mut self, new_size: usize) { + self.vals.resize(new_size, T::default_value()); + + if self.null_builder.len() < new_size { + self.null_builder + .append_n(new_size - self.null_builder.len(), false); + } + + if self.orderings.len() < new_size { + let current_len = self.orderings.len(); + + self.orderings + .resize(new_size, self.default_orderings.clone()); + + self.size_of_orderings += (new_size - current_len) + * ScalarValue::size_of_vec( + // Note: In some cases (such as in the unit test below) + // ScalarValue::size_of_vec(&self.default_orderings) != ScalarValue::size_of_vec(&self.default_orderings.clone()) + // This may be caused by the different vec.capacity() values? 
+ self.orderings.last().unwrap(), + ); + } + + if self.is_sets.len() < new_size { + self.is_sets.append_n(new_size - self.is_sets.len(), false); + } + } + + fn update_state( + &mut self, + group_idx: usize, + orderings: &[ScalarValue], + new_val: T::Native, + is_null: bool, + ) { + self.vals[group_idx] = new_val; + self.is_sets.set_bit(group_idx, true); + + self.null_builder.set_bit(group_idx, !is_null); + + assert!(orderings.len() == self.ordering_req.len()); + let old_size = ScalarValue::size_of_vec(&self.orderings[group_idx]); + self.orderings[group_idx].clear(); + self.orderings[group_idx].extend_from_slice(orderings); + let new_size = ScalarValue::size_of_vec(&self.orderings[group_idx]); + self.size_of_orderings = self.size_of_orderings - old_size + new_size; + } + + // should be used in test only + #[cfg(test)] + fn compute_size_of_orderings(&self) -> usize { + self.orderings + .iter() + .map(ScalarValue::size_of_vec) + .sum::<usize>() + } + + /// Returns a hashmap where each group (identified by `group_indices`) is mapped to + /// the index of its minimum value in `orderings`, based on lexicographical comparison. + /// The function filters values using `opt_filter` and `is_set_arr` + fn get_filtered_min_of_each_group( + &self, + orderings: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + vals: &PrimitiveArray<T>, + is_set_arr: Option<&BooleanArray>, + ) -> Result<HashMap<usize, usize>> { + let mut result = HashMap::with_capacity(orderings.len()); // group_idx -> idx_in_orderings + + let comparator = { + assert_eq!(orderings.len(), self.ordering_req.len()); + let sort_columns = orderings + .iter() + .zip(self.ordering_req.iter()) + .map(|(array, req)| SortColumn { + values: Arc::clone(array), + options: Some(req.options), + }) + .collect::<Vec<_>>(); + + LexicographicalComparator::try_new(&sort_columns)? 
+ }; + + for (idx_in_val, group_idx) in group_indices.iter().enumerate() { + let group_idx = *group_idx; + + let passed_filter = opt_filter.map(|x| x.value(idx_in_val)).unwrap_or(true); + + let is_set = is_set_arr.map(|x| x.value(idx_in_val)).unwrap_or(true); + + if !passed_filter || !is_set { + continue; + } + + if !self.need_update(group_idx) { + continue; + } + + if self.ignore_nulls && vals.is_null(idx_in_val) { + continue; + } + + if !result.contains_key(&group_idx) + || comparator + .compare(*result.get(&group_idx).unwrap(), idx_in_val) + .is_gt() + { + result.insert(group_idx, idx_in_val); + } + } + + Ok(result) + } + + fn take_vals_and_null_buf(&mut self, emit_to: EmitTo) -> ArrayRef { + let r = emit_to.take_needed(&mut self.vals); + + let null_buf = NullBuffer::new(Self::take_need(&mut self.null_builder, emit_to)); + + let values = PrimitiveArray::<T>::new(r.into(), Some(null_buf)) // no copy + .with_data_type(self.data_type.clone()); + Arc::new(values) + } +} + +impl<T> GroupsAccumulator for FirstGroupsAccumulator<T> +where + T: ArrowPrimitiveType + Send, +{ + fn update_batch( + &mut self, + values_with_orderings: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.resize_states(total_num_groups); + + let vals = values_with_orderings[0].as_primitive::<T>(); + + let mut ordering_buf = Vec::with_capacity(self.ordering_req.len()); + + for (group_idx, idx) in self Review Comment: > Inside this function, it seems to > 'compress' the current input batch with get_filtered_min_of_each_group() (if there are multiple entries for the same group, only keep the smallest one according to the specified order) Update the global state for the minimal value corresponding to all seen groups Yes. You are right. > Why is it split into two steps instead of directly updating the global state? 
According to [this](https://datafusion.apache.org/user-guide/sql/aggregate_functions.html#first-value) > Returns the first element in an aggregation group according to the requested ordering The reason for splitting it into two steps is that it performs better when cardinality is low. Benchmark SQL: `select l_shipmode, first_value(l_partkey order by l_orderkey, l_linenumber, l_comment, l_suppkey, l_tax) from 'benchmarks/data/tpch_sf10/lineitem' group by l_shipmode;` | version | time | | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------ | | main | 0.979s | | thisPR | 0.86s | | [without `get_filtered_min_of_each_group`](https://gist.githubusercontent.com/UBarney/21793799ea91f2c5b556ed0db7be85f5/raw/f5feaf9182c66144b6b48fa41be58f1baccf16de/gistfile1.txt) | 1.25s | `extract_row_at_idx_to_buf` has a relatively high overhead, so `get_filtered_min_of_each_group` is called first to avoid the problem where `extract_row_at_idx_to_buf` would otherwise be called multiple times when the same `group_idx` appears more than once in a batch. At first, I implemented it like [this](https://gist.githubusercontent.com/UBarney/21793799ea91f2c5b556ed0db7be85f5/raw/f5feaf9182c66144b6b48fa41be58f1baccf16de/gistfile1.txt), but the performance actually got worse. (At that time, I added the benchmark in `datafusion/core/benches/aggregate_query_sql.rs`, and the performance degraded from 3.9ms to 7ms.) 😂 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: github-unsubscribe@datafusion.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscribe@datafusion.apache.org For additional commands, e-mail: github-help@datafusion.apache.org