UBarney commented on code in PR #15266:
URL: https://github.com/apache/datafusion/pull/15266#discussion_r2000992780


##########
datafusion/functions-aggregate/src/first_last.rs:
##########
@@ -179,6 +292,423 @@ impl AggregateUDFImpl for FirstValue {
     }
 }
 
/// Group-wise accumulator state for `first_value` over a primitive Arrow type.
struct FirstGroupsAccumulator<T>
where
    T: ArrowPrimitiveType + Send,
{
    // ================ state ===========
    // Current "first" value recorded for each group, indexed by `group_idx`.
    vals: Vec<T::Native>,
    // Stores the ordering values (of the aggregator's ordering requirement)
    // corresponding to the first value of the aggregator.
    // The `orderings` are stored row-wise, meaning that `orderings[group_idx]`
    // represents the ordering values corresponding to the `group_idx`-th group.
    orderings: Vec<Vec<ScalarValue>>,
    // At the beginning, `is_sets[group_idx]` is false, which means `first` is not seen yet.
    // Once we see the first value, we set the `is_sets[group_idx]` flag.
    is_sets: BooleanBufferBuilder,
    // null_builder[group_idx] == false => vals[group_idx] is null
    null_builder: BooleanBufferBuilder,
    // Cached memory size of `self.orderings`.
    // Calculating the memory usage of `self.orderings` using `ScalarValue::size_of_vec`
    // is quite costly. Therefore, we cache it and recompute it only after each update,
    // so that `Self::size` can avoid calling `ScalarValue::size_of_vec`.
    size_of_orderings: usize,

    // =========== option ============

    // Stores the applicable ordering requirement.
    ordering_req: LexOrdering,
    // Sort options derived from `ordering_req`.
    sort_options: Vec<SortOptions>,
    // Stores whether incoming data already satisfies the ordering requirement.
    input_requirement_satisfied: bool,
    // Ignore null values.
    ignore_nulls: bool,
    /// The output type
    data_type: DataType,
    // One placeholder ordering value per ordering column (built from the
    // ordering data types in `try_new`); used to initialize `orderings`
    // entries for newly-seen groups in `resize_states`.
    default_orderings: Vec<ScalarValue>,
}
+
+impl<T> FirstGroupsAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send,
+{
+    fn try_new(
+        ordering_req: LexOrdering,
+        ignore_nulls: bool,
+        data_type: &DataType,
+        ordering_dtypes: &[DataType],
+    ) -> Result<Self> {
+        let requirement_satisfied = ordering_req.is_empty();
+
+        let default_orderings = ordering_dtypes
+            .iter()
+            .map(ScalarValue::try_from)
+            .collect::<Result<Vec<_>>>()?;
+
+        let sort_options = get_sort_options(ordering_req.as_ref());
+
+        Ok(Self {
+            null_builder: BooleanBufferBuilder::new(0),
+            ordering_req,
+            sort_options,
+            input_requirement_satisfied: requirement_satisfied,
+            ignore_nulls,
+            default_orderings,
+            data_type: data_type.clone(),
+            vals: Vec::new(),
+            orderings: Vec::new(),
+            is_sets: BooleanBufferBuilder::new(0),
+            size_of_orderings: 0,
+        })
+    }
+
+    fn need_update(&self, group_idx: usize) -> bool {
+        if !self.is_sets.get_bit(group_idx) {
+            return true;
+        }
+
+        if self.ignore_nulls && !self.null_builder.get_bit(group_idx) {
+            return true;
+        }
+
+        !self.input_requirement_satisfied
+    }
+
+    fn should_update_state(
+        &self,
+        group_idx: usize,
+        new_ordering_values: &[ScalarValue],
+    ) -> Result<bool> {
+        if !self.is_sets.get_bit(group_idx) {
+            return Ok(true);
+        }
+
+        assert!(new_ordering_values.len() == self.ordering_req.len());
+        let current_ordering = &self.orderings[group_idx];
+        compare_rows(current_ordering, new_ordering_values, &self.sort_options)
+            .map(|x| x.is_gt())
+    }
+
+    fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
+        let result = emit_to.take_needed(&mut self.orderings);
+
+        match emit_to {
+            EmitTo::All => self.size_of_orderings = 0,
+            EmitTo::First(_) => {
+                self.size_of_orderings -=
+                    result.iter().map(ScalarValue::size_of_vec).sum::<usize>()
+            }
+        }
+
+        result
+    }
+
+    fn take_need(
+        bool_buf_builder: &mut BooleanBufferBuilder,
+        emit_to: EmitTo,
+    ) -> BooleanBuffer {
+        let bool_buf = bool_buf_builder.finish();
+        match emit_to {
+            EmitTo::All => bool_buf,
+            EmitTo::First(n) => {
+                // split off the first N values in seen_values
+                //
+                // TODO make this more efficient rather than two
+                // copies and bitwise manipulation
+                let first_n: BooleanBuffer = bool_buf.iter().take(n).collect();
+                // reset the existing buffer
+                for b in bool_buf.iter().skip(n) {
+                    bool_buf_builder.append(b);
+                }
+                first_n
+            }
+        }
+    }
+
+    fn resize_states(&mut self, new_size: usize) {
+        self.vals.resize(new_size, T::default_value());
+
+        if self.null_builder.len() < new_size {
+            self.null_builder
+                .append_n(new_size - self.null_builder.len(), false);
+        }
+
+        if self.orderings.len() < new_size {
+            let current_len = self.orderings.len();
+
+            self.orderings
+                .resize(new_size, self.default_orderings.clone());
+
+            self.size_of_orderings += (new_size - current_len)
+                * ScalarValue::size_of_vec(
+                    // Note: In some cases (such as in the unit test below)
+                    // ScalarValue::size_of_vec(&self.default_orderings) != 
ScalarValue::size_of_vec(&self.default_orderings.clone())
+                    // This may be caused by the different vec.capacity() 
values?
+                    self.orderings.last().unwrap(),
+                );
+        }
+
+        if self.is_sets.len() < new_size {
+            self.is_sets.append_n(new_size - self.is_sets.len(), false);
+        }
+    }
+
+    fn update_state(
+        &mut self,
+        group_idx: usize,
+        orderings: &[ScalarValue],
+        new_val: T::Native,
+        is_null: bool,
+    ) {
+        self.vals[group_idx] = new_val;
+        self.is_sets.set_bit(group_idx, true);
+
+        self.null_builder.set_bit(group_idx, !is_null);
+
+        assert!(orderings.len() == self.ordering_req.len());
+        let old_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
+        self.orderings[group_idx].clear();
+        self.orderings[group_idx].extend_from_slice(orderings);
+        let new_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
+        self.size_of_orderings = self.size_of_orderings - old_size + new_size;
+    }
+
+    // should be used in test only
+    #[cfg(test)]
+    fn compute_size_of_orderings(&self) -> usize {
+        self.orderings
+            .iter()
+            .map(ScalarValue::size_of_vec)
+            .sum::<usize>()
+    }
+
+    /// Returns a hashmap where each group (identified by `group_indices`) is 
mapped to
+    /// the index of its minimum value in `orderings`, based on 
lexicographical comparison.
+    /// The function filters values using `opt_filter` and `is_set_arr`
+    fn get_filtered_min_of_each_group(
+        &self,
+        orderings: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        vals: &PrimitiveArray<T>,
+        is_set_arr: Option<&BooleanArray>,
+    ) -> Result<HashMap<usize, usize>> {
+        let mut result = HashMap::with_capacity(orderings.len()); // group_idx 
-> idx_in_orderings
+
+        let comparator = {
+            assert_eq!(orderings.len(), self.ordering_req.len());
+            let sort_columns = orderings
+                .iter()
+                .zip(self.ordering_req.iter())
+                .map(|(array, req)| SortColumn {
+                    values: Arc::clone(array),
+                    options: Some(req.options),
+                })
+                .collect::<Vec<_>>();
+
+            LexicographicalComparator::try_new(&sort_columns)?
+        };
+
+        for (idx_in_val, group_idx) in group_indices.iter().enumerate() {
+            let group_idx = *group_idx;
+
+            let passed_filter = opt_filter.map(|x| 
x.value(idx_in_val)).unwrap_or(true);
+
+            let is_set = is_set_arr.map(|x| 
x.value(idx_in_val)).unwrap_or(true);
+
+            if !passed_filter || !is_set {
+                continue;
+            }
+
+            if !self.need_update(group_idx) {
+                continue;
+            }
+
+            if self.ignore_nulls && vals.is_null(idx_in_val) {
+                continue;
+            }
+
+            if !result.contains_key(&group_idx)
+                || comparator
+                    .compare(*result.get(&group_idx).unwrap(), idx_in_val)
+                    .is_gt()
+            {
+                result.insert(group_idx, idx_in_val);
+            }
+        }
+
+        Ok(result)
+    }
+
+    fn take_vals_and_null_buf(&mut self, emit_to: EmitTo) -> ArrayRef {
+        let r = emit_to.take_needed(&mut self.vals);
+
+        let null_buf = NullBuffer::new(Self::take_need(&mut self.null_builder, 
emit_to));
+
+        let values = PrimitiveArray::<T>::new(r.into(), Some(null_buf)) // no 
copy
+            .with_data_type(self.data_type.clone());
+        Arc::new(values)
+    }
+}
+
+impl<T> GroupsAccumulator for FirstGroupsAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send,
+{
+    fn update_batch(
+        &mut self,
+        values_with_orderings: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        self.resize_states(total_num_groups);
+
+        let vals = values_with_orderings[0].as_primitive::<T>();
+
+        let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
+
+        for (group_idx, idx) in self

Review Comment:
   Good question.
   
   Because it performs better when the group-by cardinality is low.
   benchmark sql: `select l_shipmode, first_value(l_partkey order by 
l_orderkey, l_linenumber, l_comment, l_suppkey, l_tax) from 
'benchmarks/data/tpch_sf10/lineitem' group by l_shipmode;`
   
   
   | version                                                                    
                                                                                
                        | time   |
   | 
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 | ------ |
   | main                                                                       
                                                                                
                        | 0.979s |
   | thisPR                                                                     
                                                                                
                        | 0.86s   |
   | [without 
`get_filtered_min_of_each_group`](https://gist.githubusercontent.com/UBarney/21793799ea91f2c5b556ed0db7be85f5/raw/f5feaf9182c66144b6b48fa41be58f1baccf16de/gistfile1.txt)
 | 1.25s   |
   
   
   `extract_row_at_idx_to_buf` has relatively high overhead. We call 
`get_filtered_min_of_each_group` first so that `extract_row_at_idx_to_buf` is 
invoked at most once per group, instead of multiple times when the same 
`group_idx` appears repeatedly in a batch.
   
   ![Pasted image 
20250318203527](https://github.com/user-attachments/assets/22ff1ce8-27cb-41c3-ba6f-4d1dbf8d3260)
   
   
   ![Pasted image 
20250318203645](https://github.com/user-attachments/assets/5de1a4d9-ee6a-4f1d-95ce-3c35ef260a21)
   
   At first, I implemented it like 
[this](https://gist.githubusercontent.com/UBarney/21793799ea91f2c5b556ed0db7be85f5/raw/f5feaf9182c66144b6b48fa41be58f1baccf16de/gistfile1.txt),
 but the performance actually got worse.  
   (At that time, I added the benchmark in 
`datafusion/core/benches/aggregate_query_sql.rs`, and the performance degraded 
from 3.9ms to 7ms.) 😂
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to