This is an automated email from the ASF dual-hosted git repository.

ytyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 923bfb7fc7 Improve performance of `first_value` by implementing special `GroupsAccumulator` (#15266)
923bfb7fc7 is described below

commit 923bfb7fc7cf1718522f572b74f3756d02652933
Author: UBarney <[email protected]>
AuthorDate: Wed Mar 26 10:56:36 2025 +0800

    Improve performance of `first_value` by implementing special `GroupsAccumulator` (#15266)
    
    * Improve speed of first_value by implementing special GroupsAccumulator
    
    * rename and other improvements
    
    * `append_n` -> `resize`
    
    * address comment
    
    * use HashMap::entry
    
    * remove hashMap in get_filtered_min_of_each_group
---
 datafusion/common/src/utils/mod.rs                 |  17 +
 datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs |  27 +-
 .../aggregation_fuzzer/data_generator.rs           |   4 +
 .../tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs  |  66 +-
 datafusion/functions-aggregate/src/first_last.rs   | 754 ++++++++++++++++++++-
 datafusion/sqllogictest/test_files/group_by.slt    |  41 ++
 6 files changed, 897 insertions(+), 12 deletions(-)

diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs
index ff9cdedab8..409f248621 100644
--- a/datafusion/common/src/utils/mod.rs
+++ b/datafusion/common/src/utils/mod.rs
@@ -82,6 +82,23 @@ pub fn project_schema(
     Ok(schema)
 }
 
+/// Overwrites `buf` with the values found at row `idx` of each column in
+/// `columns`, in column order.
+///
+/// Unlike [`get_row_at_idx`], which returns a freshly allocated `Vec`, this
+/// reuses the caller's buffer so hot loops can avoid one allocation per row.
+///
+/// Returns the first error produced by [`ScalarValue::try_from_array`]; in
+/// that case `buf` holds only the values extracted before the failure.
+pub fn extract_row_at_idx_to_buf(
+    columns: &[ArrayRef],
+    idx: usize,
+    buf: &mut Vec<ScalarValue>,
+) -> Result<()> {
+    buf.clear();
+    for column in columns {
+        let value = ScalarValue::try_from_array(column, idx)?;
+        buf.push(value);
+    }
+    Ok(())
+}
 /// Given column vectors, returns row at `idx`.
 pub fn get_row_at_idx(columns: &[ArrayRef], idx: usize) -> 
Result<Vec<ScalarValue>> {
     columns
diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 46221acfcc..1b98a19581 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::str;
 use std::sync::Arc;
 
 use crate::fuzz_cases::aggregation_fuzzer::{
@@ -88,6 +87,32 @@ async fn test_min() {
         .await;
 }
 
+#[tokio::test(flavor = "multi_thread")]
+async fn test_first_val() {
+    let mut data_gen_config: DatasetGeneratorConfig = baseline_config();
+
+    // Every aggregate argument may also appear in the ORDER BY of
+    // `first_value`. Lifting the distinct-value cap on columns that have none
+    // makes ties in the ORDER BY columns unlikely, which keeps the expected
+    // first value unambiguous and the test stable.
+    for column in data_gen_config.columns.iter_mut() {
+        if column.get_max_num_distinct().is_none() {
+            *column = column.clone().with_max_num_distinct(usize::MAX);
+        }
+    }
+
+    let query_builder = QueryBuilder::new()
+        .with_table_name("fuzz_table")
+        .with_aggregate_function("first_value")
+        .with_aggregate_arguments(data_gen_config.all_columns())
+        .set_group_by_columns(data_gen_config.all_columns());
+
+    AggregationFuzzerBuilder::from(data_gen_config)
+        .add_query_builder(query_builder)
+        .build()
+        .run()
+        .await;
+}
+
 #[tokio::test(flavor = "multi_thread")]
 async fn test_max() {
     let data_gen_config = baseline_config();
diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs
index 54c5744c86..d61835a080 100644
--- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs
@@ -228,6 +228,10 @@ impl ColumnDescr {
         }
     }
 
+    /// Returns the configured cap on distinct values for this column, or
+    /// `None` when no cap was set (the generator then picks the count at
+    /// random, per `with_max_num_distinct`).
+    pub fn get_max_num_distinct(&self) -> Option<usize> {
+        self.max_num_distinct
+    }
+
     /// set the maximum number of distinct values in this column
     ///
     /// If `None`, the number of distinct values is randomly selected between 1
diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
index c608adda5d..bb24fb554d 100644
--- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
@@ -15,13 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::collections::HashSet;
 use std::sync::Arc;
+use std::{collections::HashSet, str::FromStr};
 
 use arrow::array::RecordBatch;
 use arrow::util::pretty::pretty_format_batches;
 use datafusion_common::{DataFusionError, Result};
 use datafusion_common_runtime::JoinSet;
+use rand::seq::SliceRandom;
 use rand::{thread_rng, Rng};
 
 use crate::fuzz_cases::aggregation_fuzzer::{
@@ -452,7 +453,11 @@ impl QueryBuilder {
     pub fn generate_query(&self) -> String {
         let group_by = self.random_group_by();
         let mut query = String::from("SELECT ");
-        query.push_str(&self.random_aggregate_functions().join(", "));
+        query.push_str(&group_by.join(", "));
+        if !group_by.is_empty() {
+            query.push_str(", ");
+        }
+        query.push_str(&self.random_aggregate_functions(&group_by).join(", "));
         query.push_str(" FROM ");
         query.push_str(&self.table_name);
         if !group_by.is_empty() {
@@ -474,7 +479,7 @@ impl QueryBuilder {
     /// * `function_names` are randomly selected from 
[`Self::aggregate_functions`]
     /// * `<DISTINCT> argument` is randomly selected from [`Self::arguments`]
     /// * `alias` is a unique alias `colN` for the column (to avoid duplicate 
column names)
-    fn random_aggregate_functions(&self) -> Vec<String> {
+    fn random_aggregate_functions(&self, group_by_cols: &[String]) -> 
Vec<String> {
         const MAX_NUM_FUNCTIONS: usize = 5;
         let mut rng = thread_rng();
         let num_aggregate_functions = rng.gen_range(1..MAX_NUM_FUNCTIONS);
@@ -482,6 +487,14 @@ impl QueryBuilder {
         let mut alias_gen = 1;
 
         let mut aggregate_functions = vec![];
+
+        let mut order_by_black_list: HashSet<String> =
+            group_by_cols.iter().cloned().collect();
+        // remove one random col
+        if let Some(first) = order_by_black_list.iter().next().cloned() {
+            order_by_black_list.remove(&first);
+        }
+
         while aggregate_functions.len() < num_aggregate_functions {
             let idx = rng.gen_range(0..self.aggregate_functions.len());
             let (function_name, is_distinct) = &self.aggregate_functions[idx];
@@ -489,7 +502,19 @@ impl QueryBuilder {
             let alias = format!("col{}", alias_gen);
             let distinct = if *is_distinct { "DISTINCT " } else { "" };
             alias_gen += 1;
-            let function = format!("{function_name}({distinct}{argument}) as 
{alias}");
+
+            let (order_by, null_opt) = if function_name.eq("first_value") {
+                (
+                    self.order_by(&order_by_black_list), /* Among the order by 
columns, at most one group by column can be included to avoid all order by 
column values being identical */
+                    self.null_opt(),
+                )
+            } else {
+                ("".to_string(), "".to_string())
+            };
+
+            let function = format!(
+                "{function_name}({distinct}{argument}{order_by}) {null_opt} as 
{alias}"
+            );
             aggregate_functions.push(function);
         }
         aggregate_functions
@@ -502,6 +527,39 @@ impl QueryBuilder {
         self.arguments[idx].clone()
     }
 
+    /// Builds a random ` order by <col> ASC|DESC,...` clause for `first_value`.
+    ///
+    /// Candidates are the aggregate arguments not in `black_list`. They are
+    /// shuffled and capped at `MAX_NUM_ORDER_BY_COLS`, and each one gets a
+    /// random direction. Returns an empty string when no candidate remains, so
+    /// the caller produces a plain `first_value(arg)` instead of a dangling
+    /// ` order by ` (which used to panic below).
+    fn order_by(&self, black_list: &HashSet<String>) -> String {
+        // Upper bound on the number of columns in one generated ORDER BY.
+        const MAX_NUM_ORDER_BY_COLS: usize = 12;
+
+        let mut rng = thread_rng();
+        let mut available_columns: Vec<&String> = self
+            .arguments
+            .iter()
+            .filter(|col| !black_list.contains(*col))
+            .collect();
+
+        // With no column the loop appends no `,`, and stripping the trailing
+        // separator would fail.
+        if available_columns.is_empty() {
+            return String::new();
+        }
+
+        available_columns.shuffle(&mut rng);
+        available_columns.truncate(MAX_NUM_ORDER_BY_COLS);
+
+        let mut result = String::from_str(" order by ").unwrap();
+        for col in available_columns {
+            let order = if rng.gen_bool(0.5) { "ASC" } else { "DESC" };
+            result.push_str(&format!("{col} {order},"));
+        }
+        // At least one column was written, so the last char is a `,`.
+        result.pop();
+        result
+    }
+
+    /// Picks `RESPECT NULLS` or `IGNORE NULLS` with equal probability, for use
+    /// as the null-treatment clause of a generated `first_value` call.
+    fn null_opt(&self) -> String {
+        let respect_nulls = thread_rng().gen_bool(0.5);
+        let clause = if respect_nulls {
+            "RESPECT NULLS"
+        } else {
+            "IGNORE NULLS"
+        };
+        clause.to_string()
+    }
+
     /// Pick a random number of fields to group by (non-repeating)
     ///
     /// Limited to 3 group by columns to ensure coverage for large groups. With
diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs
index 6df8ede4fc..28e6a8723d 100644
--- a/datafusion/functions-aggregate/src/first_last.rs
+++ b/datafusion/functions-aggregate/src/first_last.rs
@@ -22,18 +22,30 @@ use std::fmt::Debug;
 use std::mem::size_of_val;
 use std::sync::Arc;
 
-use arrow::array::{ArrayRef, AsArray, BooleanArray};
-use arrow::compute::{self, LexicographicalComparator, SortColumn};
-use arrow::datatypes::{DataType, Field};
-use datafusion_common::utils::{compare_rows, get_row_at_idx};
+use arrow::array::{
+    Array, ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray, 
BooleanBufferBuilder,
+    PrimitiveArray,
+};
+use arrow::buffer::{BooleanBuffer, NullBuffer};
+use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions};
+use arrow::datatypes::{
+    DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, 
Float16Type,
+    Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
+    Time32MillisecondType, Time32SecondType, Time64MicrosecondType, 
Time64NanosecondType,
+    TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
+    TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, 
UInt64Type,
+    UInt8Type,
+};
+use datafusion_common::cast::as_boolean_array;
+use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, 
get_row_at_idx};
 use datafusion_common::{
     arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
 };
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
 use datafusion_expr::{
-    Accumulator, AggregateUDFImpl, Documentation, Expr, ExprFunctionExt, 
Signature,
-    SortExpr, Volatility,
+    Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, 
ExprFunctionExt,
+    GroupsAccumulator, Signature, SortExpr, Volatility,
 };
 use datafusion_functions_aggregate_common::utils::get_sort_options;
 use datafusion_macros::user_doc;
@@ -153,6 +165,106 @@ impl AggregateUDFImpl for FirstValue {
         Ok(fields)
     }
 
+    /// Reports whether `create_groups_accumulator` can build the specialized
+    /// primitive accumulator for this call's return type.
+    ///
+    /// This must accept exactly the types `create_groups_accumulator`
+    /// dispatches on. Returning `true` for any other type makes that call fail
+    /// with an internal error, so `Time32`/`Time64` are restricted to the units
+    /// it handles.
+    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+        use DataType::*;
+        matches!(
+            args.return_type,
+            Int8 | Int16
+                | Int32
+                | Int64
+                | UInt8
+                | UInt16
+                | UInt32
+                | UInt64
+                | Float16
+                | Float32
+                | Float64
+                | Decimal128(_, _)
+                | Decimal256(_, _)
+                | Date32
+                | Date64
+                | Time32(TimeUnit::Second | TimeUnit::Millisecond)
+                | Time64(TimeUnit::Microsecond | TimeUnit::Nanosecond)
+                | Timestamp(_, _)
+        )
+    }
+
+    /// Builds a `FirstPrimitiveGroupsAccumulator` specialized for the call's
+    /// primitive return type.
+    ///
+    /// Returns an internal error for any type without a matching arm; callers
+    /// are expected to check `groups_accumulator_supported` first.
+    fn create_groups_accumulator(
+        &self,
+        args: AccumulatorArgs,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
+        /// Resolves the data type of every ORDER BY expression against the
+        /// input schema, then instantiates the accumulator for `T`.
+        fn build<T>(args: AccumulatorArgs) -> Result<Box<dyn GroupsAccumulator>>
+        where
+            T: ArrowPrimitiveType + Send,
+        {
+            let mut ordering_dtypes = Vec::new();
+            for sort_expr in args.ordering_req.iter() {
+                ordering_dtypes.push(sort_expr.expr.data_type(args.schema)?);
+            }
+
+            let accumulator = FirstPrimitiveGroupsAccumulator::<T>::try_new(
+                args.ordering_req.clone(),
+                args.ignore_nulls,
+                args.return_type,
+                &ordering_dtypes,
+            )?;
+            Ok(Box::new(accumulator))
+        }
+
+        use DataType::*;
+        match args.return_type {
+            Int8 => build::<Int8Type>(args),
+            Int16 => build::<Int16Type>(args),
+            Int32 => build::<Int32Type>(args),
+            Int64 => build::<Int64Type>(args),
+            UInt8 => build::<UInt8Type>(args),
+            UInt16 => build::<UInt16Type>(args),
+            UInt32 => build::<UInt32Type>(args),
+            UInt64 => build::<UInt64Type>(args),
+            Float16 => build::<Float16Type>(args),
+            Float32 => build::<Float32Type>(args),
+            Float64 => build::<Float64Type>(args),
+
+            Decimal128(_, _) => build::<Decimal128Type>(args),
+            Decimal256(_, _) => build::<Decimal256Type>(args),
+
+            Timestamp(TimeUnit::Second, _) => build::<TimestampSecondType>(args),
+            Timestamp(TimeUnit::Millisecond, _) => {
+                build::<TimestampMillisecondType>(args)
+            }
+            Timestamp(TimeUnit::Microsecond, _) => {
+                build::<TimestampMicrosecondType>(args)
+            }
+            Timestamp(TimeUnit::Nanosecond, _) => {
+                build::<TimestampNanosecondType>(args)
+            }
+
+            Date32 => build::<Date32Type>(args),
+            Date64 => build::<Date64Type>(args),
+            Time32(TimeUnit::Second) => build::<Time32SecondType>(args),
+            Time32(TimeUnit::Millisecond) => build::<Time32MillisecondType>(args),
+            Time64(TimeUnit::Microsecond) => build::<Time64MicrosecondType>(args),
+            Time64(TimeUnit::Nanosecond) => build::<Time64NanosecondType>(args),
+
+            other => internal_err!("GroupsAccumulator not supported for first({})", other),
+        }
+    }
+
+    /// No alternative names are registered for this function (empty slice).
     fn aliases(&self) -> &[String] {
         &[]
     }
@@ -179,6 +291,460 @@ impl AggregateUDFImpl for FirstValue {
     }
 }
 
+/// `GroupsAccumulator` used by `first_value` for primitive return types.
+///
+/// For each group it keeps the chosen value (`vals` + `null_builder`), the
+/// ORDER BY values of the row that value came from (`orderings`), and whether
+/// any row has been accepted yet (`is_sets`).
+struct FirstPrimitiveGroupsAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send,
+{
+    // ================ state ===========
+    // `vals[group_idx]` is the group's current first value; it is only
+    // meaningful where the matching `null_builder` bit is set.
+    vals: Vec<T::Native>,
+    // ORDER BY values (per the aggregate's ordering requirement) of the row
+    // that supplied each group's first value. Stored row-wise:
+    // `orderings[group_idx]` holds the ordering values of the `group_idx`-th
+    // group.
+    orderings: Vec<Vec<ScalarValue>>,
+    // `is_sets[group_idx]` starts out false, meaning no first value has been
+    // seen for that group. It is set once a row is accepted for the group.
+    is_sets: BooleanBufferBuilder,
+    // null_builder[group_idx] == false => vals[group_idx] is null
+    null_builder: BooleanBufferBuilder,
+    // Cached memory size of `self.orderings`. Computing it with
+    // `ScalarValue::size_of_vec` is costly, so it is maintained incrementally
+    // whenever `orderings` changes instead of being recomputed in `size()`.
+    size_of_orderings: usize,
+
+    // Scratch buffer for `get_filtered_min_of_each_group`:
+    // `min_of_each_group_buf.0[group_idx]` is the row index (`idx_in_val`) of
+    // the best candidate row for the group in the current batch; it is only
+    // valid if `min_of_each_group_buf.1[group_idx] == true`.
+    min_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),
+
+    // =========== option ============
+
+    // Stores the applicable ordering requirement.
+    ordering_req: LexOrdering,
+    // Sort options of each entry in `ordering_req` (derived in `try_new`).
+    sort_options: Vec<SortOptions>,
+    // Whether incoming data already satisfies the ordering requirement
+    // (set in `try_new` when `ordering_req` is empty).
+    input_requirement_satisfied: bool,
+    // Ignore null values.
+    ignore_nulls: bool,
+    /// The output type
+    data_type: DataType,
+    // Ordering row given to newly created groups; one `ScalarValue` per
+    // ORDER BY expression, built from the ordering data types in `try_new`.
+    default_orderings: Vec<ScalarValue>,
+}
+
+impl<T> FirstPrimitiveGroupsAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send,
+{
+    /// Creates an accumulator with no groups.
+    ///
+    /// * `ordering_req` - ORDER BY of the aggregate. When it is empty, rows
+    ///   are treated as already in order (`input_requirement_satisfied`).
+    /// * `ignore_nulls` - `true` for `IGNORE NULLS`.
+    /// * `data_type` - output type stamped on the emitted array.
+    /// * `ordering_dtypes` - types of the ORDER BY expressions. Each is
+    ///   converted with `ScalarValue::try_from` into the placeholder ordering
+    ///   value that new groups start with.
+    ///
+    /// Fails if a placeholder cannot be built for one of `ordering_dtypes`.
+    fn try_new(
+        ordering_req: LexOrdering,
+        ignore_nulls: bool,
+        data_type: &DataType,
+        ordering_dtypes: &[DataType],
+    ) -> Result<Self> {
+        let requirement_satisfied = ordering_req.is_empty();
+
+        let default_orderings = ordering_dtypes
+            .iter()
+            .map(ScalarValue::try_from)
+            .collect::<Result<Vec<_>>>()?;
+
+        let sort_options = get_sort_options(ordering_req.as_ref());
+
+        Ok(Self {
+            null_builder: BooleanBufferBuilder::new(0),
+            ordering_req,
+            sort_options,
+            input_requirement_satisfied: requirement_satisfied,
+            ignore_nulls,
+            default_orderings,
+            data_type: data_type.clone(),
+            vals: Vec::new(),
+            orderings: Vec::new(),
+            is_sets: BooleanBufferBuilder::new(0),
+            size_of_orderings: 0,
+            min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)),
+        })
+    }
+
+    /// Returns `true` if new rows may still change the result of `group_idx`.
+    /// That is the case when the group has no value yet, when it holds a null
+    /// while `ignore_nulls` is set, or when the input is not known to arrive
+    /// in ORDER BY order.
+    fn need_update(&self, group_idx: usize) -> bool {
+        if !self.is_sets.get_bit(group_idx) {
+            return true;
+        }
+
+        if self.ignore_nulls && !self.null_builder.get_bit(group_idx) {
+            return true;
+        }
+
+        !self.input_requirement_satisfied
+    }
+
+    /// Returns `true` if a row with ORDER BY values `new_ordering_values`
+    /// should replace the group's first value. That happens when the group is
+    /// unset, or when its stored ordering compares strictly greater under
+    /// `sort_options`; ties keep the existing value.
+    fn should_update_state(
+        &self,
+        group_idx: usize,
+        new_ordering_values: &[ScalarValue],
+    ) -> Result<bool> {
+        if !self.is_sets.get_bit(group_idx) {
+            return Ok(true);
+        }
+
+        assert!(new_ordering_values.len() == self.ordering_req.len());
+        let current_ordering = &self.orderings[group_idx];
+        compare_rows(current_ordering, new_ordering_values, &self.sort_options)
+            .map(|x| x.is_gt())
+    }
+
+    /// Removes and returns the orderings of the groups selected by `emit_to`,
+    /// keeping `size_of_orderings` in sync with what remains.
+    fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
+        let result = emit_to.take_needed(&mut self.orderings);
+
+        match emit_to {
+            EmitTo::All => self.size_of_orderings = 0,
+            EmitTo::First(_) => {
+                self.size_of_orderings -=
+                    result.iter().map(ScalarValue::size_of_vec).sum::<usize>()
+            }
+        }
+
+        result
+    }
+
+    /// Splits `bool_buf_builder` according to `emit_to`. Returns the emitted
+    /// bits and leaves only the remaining bits in the builder: `finish`
+    /// empties it, then the tail is appended back.
+    fn take_need(
+        bool_buf_builder: &mut BooleanBufferBuilder,
+        emit_to: EmitTo,
+    ) -> BooleanBuffer {
+        let bool_buf = bool_buf_builder.finish();
+        match emit_to {
+            EmitTo::All => bool_buf,
+            EmitTo::First(n) => {
+                // split off the first N values in seen_values
+                //
+                // TODO make this more efficient rather than two
+                // copies and bitwise manipulation
+                let first_n: BooleanBuffer = bool_buf.iter().take(n).collect();
+                // reset the existing buffer
+                for b in bool_buf.iter().skip(n) {
+                    bool_buf_builder.append(b);
+                }
+                first_n
+            }
+        }
+    }
+
+    /// Grows all per-group state to `new_size` groups. New groups get a
+    /// default value, unset `is_sets`/`null_builder` bits
+    /// (`BooleanBufferBuilder::resize` pads with unset bits), and a copy of
+    /// `default_orderings`.
+    fn resize_states(&mut self, new_size: usize) {
+        self.vals.resize(new_size, T::default_value());
+
+        self.null_builder.resize(new_size);
+
+        if self.orderings.len() < new_size {
+            let current_len = self.orderings.len();
+
+            self.orderings
+                .resize(new_size, self.default_orderings.clone());
+
+            self.size_of_orderings += (new_size - current_len)
+                * ScalarValue::size_of_vec(
+                    // Measure a stored clone rather than `self.default_orderings`.
+                    // In some cases (e.g. the unit test below) the two report
+                    // different sizes, possibly because their `Vec` capacities differ.
+                    self.orderings.last().unwrap(),
+                );
+        }
+
+        self.is_sets.resize(new_size);
+
+        self.min_of_each_group_buf.0.resize(new_size, 0);
+        self.min_of_each_group_buf.1.resize(new_size);
+    }
+
+    /// Records `new_val` (null if `is_null`) with ORDER BY values `orderings`
+    /// as the first value of `group_idx`, updating `size_of_orderings`.
+    fn update_state(
+        &mut self,
+        group_idx: usize,
+        orderings: &[ScalarValue],
+        new_val: T::Native,
+        is_null: bool,
+    ) {
+        self.vals[group_idx] = new_val;
+        self.is_sets.set_bit(group_idx, true);
+
+        self.null_builder.set_bit(group_idx, !is_null);
+
+        assert!(orderings.len() == self.ordering_req.len());
+        let old_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
+        self.orderings[group_idx].clear();
+        self.orderings[group_idx].extend_from_slice(orderings);
+        let new_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
+        self.size_of_orderings = self.size_of_orderings - old_size + new_size;
+    }
+
+    /// Removes the state of the groups selected by `emit_to`. Returns the
+    /// values as an array, their orderings row by row, and the `is_sets` bits.
+    fn take_state(
+        &mut self,
+        emit_to: EmitTo,
+    ) -> (ArrayRef, Vec<Vec<ScalarValue>>, BooleanBuffer) {
+        emit_to.take_needed(&mut self.min_of_each_group_buf.0);
+        self.min_of_each_group_buf
+            .1
+            .truncate(self.min_of_each_group_buf.0.len());
+
+        (
+            self.take_vals_and_null_buf(emit_to),
+            self.take_orderings(emit_to),
+            Self::take_need(&mut self.is_sets, emit_to),
+        )
+    }
+
+    // should be used in test only
+    #[cfg(test)]
+    fn compute_size_of_orderings(&self) -> usize {
+        self.orderings
+            .iter()
+            .map(ScalarValue::size_of_vec)
+            .sum::<usize>()
+    }
+
+    /// Picks one candidate row per group from this batch.
+    ///
+    /// Only groups that can still change (see `need_update`) are considered.
+    /// For each of them, the row with the smallest ORDER BY values under
+    /// `ordering_req` is selected, and `(group_idx, idx_in_val)` pairs are
+    /// returned in ascending `group_idx` order.
+    ///
+    /// A row is eligible only if:
+    /// * it passes `opt_filter` and `is_set_arr`, where NULL entries in
+    ///   either count as `false`;
+    /// * its value is not null, when `ignore_nulls` is set.
+    ///
+    /// On ties the earliest row wins.
+    fn get_filtered_min_of_each_group(
+        &mut self,
+        orderings: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        vals: &PrimitiveArray<T>,
+        is_set_arr: Option<&BooleanArray>,
+    ) -> Result<Vec<(usize, usize)>> {
+        // Set all values in min_of_each_group_buf.1 to false.
+        self.min_of_each_group_buf.1.truncate(0);
+        self.min_of_each_group_buf
+            .1
+            .append_n(self.vals.len(), false);
+
+        // No need to clear `min_of_each_group_buf.0`: an entry is only read
+        // when the matching bit in `min_of_each_group_buf.1` is set.
+
+        let comparator = {
+            assert_eq!(orderings.len(), self.ordering_req.len());
+            let sort_columns = orderings
+                .iter()
+                .zip(self.ordering_req.iter())
+                .map(|(array, req)| SortColumn {
+                    values: Arc::clone(array),
+                    options: Some(req.options),
+                })
+                .collect::<Vec<_>>();
+
+            LexicographicalComparator::try_new(&sort_columns)?
+        };
+
+        for (idx_in_val, group_idx) in group_indices.iter().enumerate() {
+            let group_idx = *group_idx;
+
+            // A NULL filter / `is_set` entry must not select the row;
+            // `value()` alone would read whatever bit sits under the null.
+            let passed_filter = opt_filter
+                .is_none_or(|f| f.is_valid(idx_in_val) && f.value(idx_in_val));
+
+            let is_set = is_set_arr
+                .is_none_or(|s| s.is_valid(idx_in_val) && s.value(idx_in_val));
+
+            if !passed_filter || !is_set {
+                continue;
+            }
+
+            if !self.need_update(group_idx) {
+                continue;
+            }
+
+            if self.ignore_nulls && vals.is_null(idx_in_val) {
+                continue;
+            }
+
+            let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx);
+            if is_valid
+                && comparator
+                    .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val)
+                    .is_gt()
+            {
+                self.min_of_each_group_buf.0[group_idx] = idx_in_val;
+            } else if !is_valid {
+                self.min_of_each_group_buf.1.set_bit(group_idx, true);
+                self.min_of_each_group_buf.0[group_idx] = idx_in_val;
+            }
+        }
+
+        Ok(self
+            .min_of_each_group_buf
+            .0
+            .iter()
+            .enumerate()
+            .filter(|(group_idx, _)| self.min_of_each_group_buf.1.get_bit(*group_idx))
+            .map(|(group_idx, idx_in_val)| (group_idx, *idx_in_val))
+            .collect::<Vec<_>>())
+    }
+
+    /// Removes the values of the groups selected by `emit_to`. Returns them as
+    /// a `PrimitiveArray` of `data_type`, whose validity comes from
+    /// `null_builder`.
+    fn take_vals_and_null_buf(&mut self, emit_to: EmitTo) -> ArrayRef {
+        let r = emit_to.take_needed(&mut self.vals);
+
+        let null_buf = NullBuffer::new(Self::take_need(&mut self.null_builder, emit_to));
+
+        let values = PrimitiveArray::<T>::new(r.into(), Some(null_buf)) // no copy
+            .with_data_type(self.data_type.clone());
+        Arc::new(values)
+    }
+}
+
+impl<T> GroupsAccumulator for FirstPrimitiveGroupsAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send,
+{
+    /// Folds a batch of raw input rows into the per-group state.
+    /// `values_and_order_cols[0]` is the aggregated column; the remaining
+    /// columns are the ORDER BY expressions.
+    fn update_batch(
+        &mut self,
+        // e.g. first_value(a order by b): values_and_order_cols will be [a, b]
+        values_and_order_cols: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        self.resize_states(total_num_groups);
+
+        let vals = values_and_order_cols[0].as_primitive::<T>();
+
+        let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
+
+        // `extract_row_at_idx_to_buf` is relatively expensive. So first pick
+        // one candidate row per group, then materialize only those rows.
+        for (group_idx, idx) in self
+            .get_filtered_min_of_each_group(
+                &values_and_order_cols[1..],
+                group_indices,
+                opt_filter,
+                vals,
+                None,
+            )?
+            .into_iter()
+        {
+            extract_row_at_idx_to_buf(
+                &values_and_order_cols[1..],
+                idx,
+                &mut ordering_buf,
+            )?;
+
+            if self.should_update_state(group_idx, &ordering_buf)? {
+                self.update_state(
+                    group_idx,
+                    &ordering_buf,
+                    vals.value(idx),
+                    vals.is_null(idx),
+                );
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Emits the first value of each group selected by `emit_to` and drops
+    /// their state.
+    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+        Ok(self.take_state(emit_to).0)
+    }
+
+    /// Emits intermediate state for the groups selected by `emit_to` as
+    /// `[value, ordering_1, ..., ordering_k, is_set]`.
+    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        let (val_arr, orderings, is_sets) = self.take_state(emit_to);
+        // One value column, one column per ORDER BY expression, one `is_set`
+        // column. (`self.orderings` now holds only the *remaining* groups, so
+        // it must not be used for sizing here.)
+        let mut result = Vec::with_capacity(self.ordering_req.len() + 2);
+
+        result.push(val_arr);
+
+        let ordering_cols = {
+            let mut ordering_cols = Vec::with_capacity(self.ordering_req.len());
+            for _ in 0..self.ordering_req.len() {
+                ordering_cols.push(Vec::with_capacity(orderings.len()));
+            }
+            for row in orderings.into_iter() {
+                assert_eq!(row.len(), self.ordering_req.len());
+                for (col_idx, ordering) in row.into_iter().enumerate() {
+                    ordering_cols[col_idx].push(ordering);
+                }
+            }
+
+            ordering_cols
+        };
+        for (default_ordering, ordering_col) in
+            self.default_orderings.iter().zip(ordering_cols)
+        {
+            let array = if ordering_col.is_empty() {
+                // `ScalarValue::iter_to_array` rejects an empty iterator, so
+                // when no group is emitted build a typed empty array instead.
+                default_ordering.to_array_of_size(0)?
+            } else {
+                ScalarValue::iter_to_array(ordering_col)?
+            };
+            result.push(array);
+        }
+
+        result.push(Arc::new(BooleanArray::new(is_sets, None)));
+
+        Ok(result)
+    }
+
+    /// Merges intermediate state produced by `state` or `convert_to_state`.
+    /// The layout is the value column, then the ORDER BY columns, then
+    /// `is_set` last.
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        self.resize_states(total_num_groups);
+
+        let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
+
+        let (is_set_arr, val_and_order_cols) = match values.split_last() {
+            Some(result) => result,
+            None => return internal_err!("Empty row in FIRST_VALUE"),
+        };
+
+        let is_set_arr = as_boolean_array(is_set_arr)?;
+
+        let vals = values[0].as_primitive::<T>();
+        // `extract_row_at_idx_to_buf` is relatively expensive. So first pick
+        // one candidate row per group, then materialize only those rows.
+        let groups = self.get_filtered_min_of_each_group(
+            &val_and_order_cols[1..],
+            group_indices,
+            opt_filter,
+            vals,
+            Some(is_set_arr),
+        )?;
+
+        for (group_idx, idx) in groups.into_iter() {
+            extract_row_at_idx_to_buf(&val_and_order_cols[1..], idx, &mut ordering_buf)?;
+
+            if self.should_update_state(group_idx, &ordering_buf)? {
+                self.update_state(
+                    group_idx,
+                    &ordering_buf,
+                    vals.value(idx),
+                    vals.is_null(idx),
+                );
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Approximate heap usage in bytes. Bit-buffer capacities are in bits and
+    /// are divided by 8; orderings use the cached `size_of_orderings`.
+    fn size(&self) -> usize {
+        self.vals.capacity() * size_of::<T::Native>()
+            + self.null_builder.capacity() / 8 // capacity is in bits, so convert to bytes
+            + self.is_sets.capacity() / 8
+            + self.size_of_orderings
+            + self.min_of_each_group_buf.0.capacity() * size_of::<usize>()
+            + self.min_of_each_group_buf.1.capacity() / 8
+    }
+
+    fn supports_convert_to_state(&self) -> bool {
+        true
+    }
+
+    /// Turns raw input rows directly into state. The value and ORDER BY
+    /// columns pass through unchanged, followed by an `is_set` column taken
+    /// from `opt_filter` (all `true` when there is no filter).
+    fn convert_to_state(
+        &self,
+        values: &[ArrayRef],
+        opt_filter: Option<&BooleanArray>,
+    ) -> Result<Vec<ArrayRef>> {
+        let is_set = match opt_filter {
+            Some(f) => f.clone(),
+            None => BooleanArray::from(vec![true; values[0].len()]),
+        };
+
+        let mut result = values.to_vec();
+        result.push(Arc::new(is_set));
+        Ok(result)
+    }
+}
 #[derive(Debug)]
 pub struct FirstValueAccumulator {
     first: ScalarValue,
@@ -684,7 +1250,8 @@ fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec<Sort
 
 #[cfg(test)]
 mod tests {
-    use arrow::array::Int64Array;
+    use arrow::{array::Int64Array, compute::SortOptions, datatypes::Schema};
+    use datafusion_physical_expr::{expressions::col, PhysicalSortExpr};
 
     use super::*;
 
@@ -823,4 +1390,177 @@ mod tests {
 
         Ok(())
     }
+
+    // End-to-end check of `FirstPrimitiveGroupsAccumulator<Int64Type>`. The sequence is:
+    // a filtered `update_batch`, `state`, a filtered `merge_batch`, a second
+    // `update_batch` that adds a group, then `evaluate`. After each step, the
+    // incrementally tracked `size_of_orderings` must equal a fresh
+    // `compute_size_of_orderings()`.
+    #[test]
+    fn test_frist_group_acc() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int64, true),
+            Field::new("b", DataType::Int64, true),
+            Field::new("c", DataType::Int64, true),
+            Field::new("d", DataType::Int32, true),
+            Field::new("e", DataType::Boolean, true),
+        ]));
+
+        // Order by column "c" with default `SortOptions` (ascending).
+        let sort_key = LexOrdering::new(vec![PhysicalSortExpr {
+            expr: col("c", &schema).unwrap(),
+            options: SortOptions::default(),
+        }]);
+
+        // NOTE(review): the `true` argument is presumably `ignore_nulls`. The expected
+        // state below is consistent with that: group 1 keeps -6 rather than the NULL
+        // value at ordering -9.
+        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
+            sort_key,
+            true,
+            &DataType::Int64,
+            &[DataType::Int64],
+        )?;
+
+        // Input columns: the value column followed by a single Int64 ordering column.
+        let mut val_with_orderings = {
+            let mut val_with_orderings = Vec::<ArrayRef>::new();
+
+            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
+            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
+
+            val_with_orderings.push(vals);
+            val_with_orderings.push(orderings);
+
+            val_with_orderings
+        };
+
+        // Rows map to groups [0, 1, 2, 1]. The filter drops row 2, so group 2 receives
+        // no input and must stay unset.
+        group_acc.update_batch(
+            &val_with_orderings,
+            &[0, 1, 2, 1],
+            Some(&BooleanArray::from(vec![true, true, false, true])),
+            3,
+        )?;
+        assert_eq!(
+            group_acc.size_of_orderings,
+            group_acc.compute_size_of_orderings()
+        );
+
+        // State layout, one row per group: [value, ordering, is_set].
+        let state = group_acc.state(EmitTo::All)?;
+
+        let expected_state: Vec<Arc<dyn Array>> = vec![
+            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
+            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
+            Arc::new(BooleanArray::from(vec![true, true, false])),
+        ];
+        assert_eq!(state, expected_state);
+
+        assert_eq!(
+            group_acc.size_of_orderings,
+            group_acc.compute_size_of_orderings()
+        );
+
+        // Merge back only state row 0 (the filter rejects rows 1 and 2). The final
+        // `evaluate` expects group 1 to be 6 rather than -6, which relies on
+        // `state(EmitTo::All)` having emptied the accumulator.
+        group_acc.merge_batch(
+            &state,
+            &[0, 1, 2],
+            Some(&BooleanArray::from(vec![true, false, false])),
+            3,
+        )?;
+
+        assert_eq!(
+            group_acc.size_of_orderings,
+            group_acc.compute_size_of_orderings()
+        );
+
+        // New rows for groups 1 and 2. `total_num_groups` grows to 4, so group 3 exists
+        // but never receives a value.
+        val_with_orderings.clear();
+        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));
+        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));
+
+        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;
+
+        let binding = group_acc.evaluate(EmitTo::All)?;
+        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();
+
+        // Group 0 comes from the merge, groups 1 and 2 from the last batch, and
+        // group 3 is NULL.
+        let expect: PrimitiveArray<Int64Type> =
+            Int64Array::from(vec![Some(1), Some(6), Some(6), None]);
+
+        assert_eq!(eval_result, &expect);
+
+        assert_eq!(
+            group_acc.size_of_orderings,
+            group_acc.compute_size_of_orderings()
+        );
+
+        Ok(())
+    }
+
+    // Stress test for the `size_of_orderings` bookkeeping over 100 groups. Each
+    // iteration cycles through `update_batch`, partial and full `state`, `merge_batch`,
+    // and partial and full `evaluate`. After every call, the tracked size must equal
+    // `compute_size_of_orderings()`.
+    #[test]
+    fn test_frist_group_acc_size_of_ordering() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int64, true),
+            Field::new("b", DataType::Int64, true),
+            Field::new("c", DataType::Int64, true),
+            Field::new("d", DataType::Int32, true),
+            Field::new("e", DataType::Boolean, true),
+        ]));
+
+        let sort_key = LexOrdering::new(vec![PhysicalSortExpr {
+            expr: col("c", &schema).unwrap(),
+            options: SortOptions::default(),
+        }]);
+
+        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
+            sort_key,
+            true,
+            &DataType::Int64,
+            &[DataType::Int64],
+        )?;
+
+        // Input columns: the value column followed by a single Int64 ordering column.
+        let val_with_orderings = {
+            let mut val_with_orderings = Vec::<ArrayRef>::new();
+
+            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
+            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
+
+            val_with_orderings.push(vals);
+            val_with_orderings.push(orderings);
+
+            val_with_orderings
+        };
+
+        for _ in 0..10 {
+            group_acc.update_batch(
+                &val_with_orderings,
+                &[0, 1, 2, 1],
+                Some(&BooleanArray::from(vec![true, true, false, true])),
+                100,
+            )?;
+            assert_eq!(
+                group_acc.size_of_orderings,
+                group_acc.compute_size_of_orderings()
+            );
+
+            // Emit only the first two groups; the emitted state itself is discarded.
+            group_acc.state(EmitTo::First(2))?;
+            assert_eq!(
+                group_acc.size_of_orderings,
+                group_acc.compute_size_of_orderings()
+            );
+
+            let s = group_acc.state(EmitTo::All)?;
+            assert_eq!(
+                group_acc.size_of_orderings,
+                group_acc.compute_size_of_orderings()
+            );
+
+            // Merge every emitted state row i into group i.
+            group_acc.merge_batch(&s, &Vec::from_iter(0..s[0].len()), None, 100)?;
+            assert_eq!(
+                group_acc.size_of_orderings,
+                group_acc.compute_size_of_orderings()
+            );
+
+            group_acc.evaluate(EmitTo::First(2))?;
+            assert_eq!(
+                group_acc.size_of_orderings,
+                group_acc.compute_size_of_orderings()
+            );
+
+            group_acc.evaluate(EmitTo::All)?;
+            assert_eq!(
+                group_acc.size_of_orderings,
+                group_acc.compute_size_of_orderings()
+            );
+        }
+
+        Ok(())
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt
index 5bf539e0b0..d9ef12496e 100644
--- a/datafusion/sqllogictest/test_files/group_by.slt
+++ b/datafusion/sqllogictest/test_files/group_by.slt
@@ -2665,6 +2665,47 @@ TUR [100.0, 75.0] 175
 # test_reverse_aggregate_expr
 # Some of the Aggregators can be reversed, by this way we can still run aggregators without re-ordering
 # that have contradictory requirements at first glance.
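+
+# first_value with GROUP BY edge cases: NULL group keys, and IGNORE NULLS vs RESPECT NULLS.
+# (The test_reverse_aggregate_expr comment above describes the EXPLAIN query that
+# follows these cases, not these cases themselves.)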
+statement ok
+CREATE TABLE null_group (
+          a INT, b INT, c INT, d INT
+) as VALUES
+  (6, 6, null, null),
+  (6, 6, 1, null),
+  (6, 6, null, 1)
+
+query III rowsort
+select c, d, first_value(a order by b) from null_group group by c, d;
+----
+1 NULL 6
+NULL 1 6
+NULL NULL 6
+
+
+
+statement ok
+CREATE TABLE first_null (
+          k INT,
+          val INT,
+          o int
+        ) as VALUES
+          (0, NULL, -9),
+          (0, 1, 1),
+          (1, 1, 1);
+
+query II rowsort
+select k, first_value(val order by o) IGNORE NULLS from first_null group by k;
+----
+0 1
+1 1
+
+query II rowsort
+select k, first_value(val order by o) respect NULLS from first_null group by k;
+----
+0 NULL
+1 1
+
+
 query TT
 EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts,
   FIRST_VALUE(amount ORDER BY amount ASC) AS fv1,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to