This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new b22e4d2648 Improve performance of `last_value` by implementing special `GroupsAccumulator` (#15542)
b22e4d2648 is described below

commit b22e4d26480202247022bf2c94ba629a07e3275b
Author: UBarney <[email protected]>
AuthorDate: Thu Apr 10 22:49:06 2025 +0800

    Improve performance of `last_value` by implementing special `GroupsAccumulator` (#15542)
    
    * Improve performance of `last_value` by implementing special `GroupsAccumulator`
    
    * less diff
---
 datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs |  26 +++
 .../tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs  |   4 +-
 datafusion/functions-aggregate/src/first_last.rs   | 231 +++++++++++++++++++--
 datafusion/sqllogictest/test_files/group_by.slt    |  27 ++-
 4 files changed, 270 insertions(+), 18 deletions(-)

diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 4ba63e9991..ff3b66986c 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -120,6 +120,34 @@ async fn test_first_val() {
         .await;
 }
 
+/// Fuzz test for `last_value(... ORDER BY ...)` over the baseline data-generation columns.
+#[tokio::test(flavor = "multi_thread")]
+async fn test_last_val() {
+    let mut data_gen_config = baseline_config();
+
+    for i in 0..data_gen_config.columns.len() {
+        if data_gen_config.columns[i].get_max_num_distinct().is_none() {
+            data_gen_config.columns[i] = data_gen_config.columns[i]
+                .clone()
+                // Minimize the chance of identical values in the order by columns
+                // to make the test more stable
+                .with_max_num_distinct(usize::MAX);
+        }
+    }
+
+    let query_builder = QueryBuilder::new()
+        .with_table_name("fuzz_table")
+        .with_aggregate_function("last_value")
+        .with_aggregate_arguments(data_gen_config.all_columns())
+        .set_group_by_columns(data_gen_config.all_columns());
+
+    AggregationFuzzerBuilder::from(data_gen_config)
+        .add_query_builder(query_builder)
+        .build()
+        .run()
+        .await;
+}
+
 #[tokio::test(flavor = "multi_thread")]
 async fn test_max() {
     let data_gen_config = baseline_config();
diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
index 1e97ddcdb1..53e9288ab4 100644
--- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
@@ -503,7 +503,9 @@ impl QueryBuilder {
             let distinct = if *is_distinct { "DISTINCT " } else { "" };
             alias_gen += 1;
 
-            let (order_by, null_opt) = if function_name.eq("first_value") {
+            let (order_by, null_opt) = if function_name.eq("first_value")
+                || function_name.eq("last_value")
+            {
                 (
             self.order_by(&order_by_black_list), /* Among the order by columns, at most one group by column can be included to avoid all order by column values being identical */
                     self.null_opt(),
diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs
index 28e6a8723d..6465436375 100644
--- a/datafusion/functions-aggregate/src/first_last.rs
+++ b/datafusion/functions-aggregate/src/first_last.rs
@@ -166,6 +166,7 @@ impl AggregateUDFImpl for FirstValue {
     }
 
     fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+        // TODO: extract to function
         use DataType::*;
         matches!(
             args.return_type,
@@ -193,6 +194,7 @@ impl AggregateUDFImpl for FirstValue {
         &self,
         args: AccumulatorArgs,
     ) -> Result<Box<dyn GroupsAccumulator>> {
+        // TODO: extract to function
         fn create_accumulator<T>(
             args: AccumulatorArgs,
         ) -> Result<Box<dyn GroupsAccumulator>>
@@ -210,6 +212,7 @@ impl AggregateUDFImpl for FirstValue {
                 args.ignore_nulls,
                 args.return_type,
                 &ordering_dtypes,
+                true,
             )?))
         }
 
@@ -258,10 +261,12 @@ impl AggregateUDFImpl for FirstValue {
                 create_accumulator::<Time64NanosecondType>(args)
             }
 
-            _ => internal_err!(
-                "GroupsAccumulator not supported for first({})",
-                args.return_type
-            ),
+            _ => {
+                internal_err!(
+                    "GroupsAccumulator not supported for first_value({})",
+                    args.return_type
+                )
+            }
         }
     }
 
@@ -291,6 +296,7 @@ impl AggregateUDFImpl for FirstValue {
     }
 }
 
+// TODO: rename to PrimitiveGroupsAccumulator
 struct FirstPrimitiveGroupsAccumulator<T>
 where
     T: ArrowPrimitiveType + Send,
@@ -316,12 +322,16 @@ where
     // buffer for `get_filtered_min_of_each_group`
     // filter_min_of_each_group_buf.0[group_idx] -> idx_in_val
     // only valid if filter_min_of_each_group_buf.1[group_idx] == true
+    // TODO: rename to extreme_of_each_group_buf
     min_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),
 
     // =========== option ============
 
     // Stores the applicable ordering requirement.
     ordering_req: LexOrdering,
+    // true: take first element in an aggregation group according to the requested ordering.
+    // false: take last element in an aggregation group according to the requested ordering.
+    pick_first_in_group: bool,
     // derived from `ordering_req`.
     sort_options: Vec<SortOptions>,
     // Stores whether incoming data already satisfies the ordering requirement.
@@ -342,6 +352,7 @@ where
         ignore_nulls: bool,
         data_type: &DataType,
         ordering_dtypes: &[DataType],
+        pick_first_in_group: bool,
     ) -> Result<Self> {
         let requirement_satisfied = ordering_req.is_empty();
 
@@ -365,6 +376,7 @@ where
             is_sets: BooleanBufferBuilder::new(0),
             size_of_orderings: 0,
             min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)),
+            pick_first_in_group,
         })
     }
 
@@ -391,8 +403,13 @@ where
 
         assert!(new_ordering_values.len() == self.ordering_req.len());
         let current_ordering = &self.orderings[group_idx];
-        compare_rows(current_ordering, new_ordering_values, &self.sort_options)
-            .map(|x| x.is_gt())
+        compare_rows(current_ordering, new_ordering_values, 
&self.sort_options).map(|x| {
+            if self.pick_first_in_group {
+                x.is_gt()
+            } else {
+                x.is_lt()
+            }
+        })
     }
 
     fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
@@ -501,10 +518,10 @@ where
             .map(ScalarValue::size_of_vec)
             .sum::<usize>()
     }
-
     /// Returns a vector of tuples `(group_idx, idx_in_val)` representing the index of the
     /// minimum value in `orderings` for each group, using lexicographical comparison.
     /// Values are filtered using `opt_filter` and `is_set_arr` if provided.
+    /// TODO: rename to get_filtered_extreme_of_each_group
     fn get_filtered_min_of_each_group(
         &mut self,
         orderings: &[ArrayRef],
@@ -556,15 +573,19 @@ where
             }
 
             let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx);
-            if is_valid
-                && comparator
-                    .compare(self.min_of_each_group_buf.0[group_idx], 
idx_in_val)
-                    .is_gt()
-            {
-                self.min_of_each_group_buf.0[group_idx] = idx_in_val;
-            } else if !is_valid {
+
+            if !is_valid {
                 self.min_of_each_group_buf.1.set_bit(group_idx, true);
                 self.min_of_each_group_buf.0[group_idx] = idx_in_val;
+            } else {
+                let ordering = comparator
+                    .compare(self.min_of_each_group_buf.0[group_idx], 
idx_in_val);
+
+                if (ordering.is_gt() && self.pick_first_in_group)
+                    || (ordering.is_lt() && !self.pick_first_in_group)
+                {
+                    self.min_of_each_group_buf.0[group_idx] = idx_in_val;
+                }
             }
         }
 
@@ -1052,6 +1073,114 @@ impl AggregateUDFImpl for LastValue {
     fn documentation(&self) -> Option<&Documentation> {
         self.doc()
     }
+
+    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+        use DataType::*;
+        // The groups accumulator keeps the row that sorts last under the ORDER BY.
+        // With no ORDER BY the (empty) sort keys of all rows compare equal, so it
+        // would keep the first row it sees instead of the last one in input
+        // order; leave that case to the row-based accumulator.
+        !args.ordering_req.is_empty()
+            && matches!(
+                args.return_type,
+                Int8 | Int16
+                    | Int32
+                    | Int64
+                    | UInt8
+                    | UInt16
+                    | UInt32
+                    | UInt64
+                    | Float16
+                    | Float32
+                    | Float64
+                    | Decimal128(_, _)
+                    | Decimal256(_, _)
+                    | Date32
+                    | Date64
+                    | Time32(_)
+                    | Time64(_)
+                    | Timestamp(_, _)
+            )
+    }
+
+    fn create_groups_accumulator(
+        &self,
+        args: AccumulatorArgs,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
+        fn create_accumulator<T>(
+            args: AccumulatorArgs,
+        ) -> Result<Box<dyn GroupsAccumulator>>
+        where
+            T: ArrowPrimitiveType + Send,
+        {
+            let ordering_dtypes = args
+                .ordering_req
+                .iter()
+                .map(|e| e.expr.data_type(args.schema))
+                .collect::<Result<Vec<_>>>()?;
+
+            Ok(Box::new(FirstPrimitiveGroupsAccumulator::<T>::try_new(
+                args.ordering_req.clone(),
+                args.ignore_nulls,
+                args.return_type,
+                &ordering_dtypes,
+                false, // pick_first_in_group: keep the last row of each group
+            )?))
+        }
+
+        match args.return_type {
+            DataType::Int8 => create_accumulator::<Int8Type>(args),
+            DataType::Int16 => create_accumulator::<Int16Type>(args),
+            DataType::Int32 => create_accumulator::<Int32Type>(args),
+            DataType::Int64 => create_accumulator::<Int64Type>(args),
+            DataType::UInt8 => create_accumulator::<UInt8Type>(args),
+            DataType::UInt16 => create_accumulator::<UInt16Type>(args),
+            DataType::UInt32 => create_accumulator::<UInt32Type>(args),
+            DataType::UInt64 => create_accumulator::<UInt64Type>(args),
+            DataType::Float16 => create_accumulator::<Float16Type>(args),
+            DataType::Float32 => create_accumulator::<Float32Type>(args),
+            DataType::Float64 => create_accumulator::<Float64Type>(args),
+
+            DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
+            DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),
+
+            DataType::Timestamp(TimeUnit::Second, _) => {
+                create_accumulator::<TimestampSecondType>(args)
+            }
+            DataType::Timestamp(TimeUnit::Millisecond, _) => {
+                create_accumulator::<TimestampMillisecondType>(args)
+            }
+            DataType::Timestamp(TimeUnit::Microsecond, _) => {
+                create_accumulator::<TimestampMicrosecondType>(args)
+            }
+            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+                create_accumulator::<TimestampNanosecondType>(args)
+            }
+
+            DataType::Date32 => create_accumulator::<Date32Type>(args),
+            DataType::Date64 => create_accumulator::<Date64Type>(args),
+            DataType::Time32(TimeUnit::Second) => {
+                create_accumulator::<Time32SecondType>(args)
+            }
+            DataType::Time32(TimeUnit::Millisecond) => {
+                create_accumulator::<Time32MillisecondType>(args)
+            }
+
+            DataType::Time64(TimeUnit::Microsecond) => {
+                create_accumulator::<Time64MicrosecondType>(args)
+            }
+            DataType::Time64(TimeUnit::Nanosecond) => {
+                create_accumulator::<Time64NanosecondType>(args)
+            }
+
+            _ => {
+                internal_err!(
+                    "GroupsAccumulator not supported for last_value({})",
+                    args.return_type
+                )
+            }
+        }
+    }
 }
 
 #[derive(Debug)]
@@ -1411,6 +1535,7 @@ mod tests {
             true,
             &DataType::Int64,
             &[DataType::Int64],
+            true,
         )?;
 
         let mut val_with_orderings = {
@@ -1485,7 +1610,7 @@ mod tests {
     }
 
     #[test]
-    fn test_frist_group_acc_size_of_ordering() -> Result<()> {
+    fn test_group_acc_size_of_ordering() -> Result<()> {
         let schema = Arc::new(Schema::new(vec![
             Field::new("a", DataType::Int64, true),
             Field::new("b", DataType::Int64, true),
@@ -1504,6 +1629,7 @@ mod tests {
             true,
             &DataType::Int64,
             &[DataType::Int64],
+            true,
         )?;
 
         let val_with_orderings = {
@@ -1563,4 +1689,85 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_last_group_acc() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int64, true),
+            Field::new("b", DataType::Int64, true),
+            Field::new("c", DataType::Int64, true),
+            Field::new("d", DataType::Int32, true),
+            Field::new("e", DataType::Boolean, true),
+        ]));
+
+        let sort_key = LexOrdering::new(vec![PhysicalSortExpr {
+            expr: col("c", &schema).unwrap(),
+            options: SortOptions::default(),
+        }]);
+
+        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
+            sort_key,
+            true,
+            &DataType::Int64,
+            &[DataType::Int64],
+            false, // pick_first_in_group = false, i.e. last_value semantics
+        )?;
+
+        let mut val_with_orderings = {
+            let mut val_with_orderings = Vec::<ArrayRef>::new();
+
+            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
+            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
+
+            val_with_orderings.push(vals);
+            val_with_orderings.push(orderings);
+
+            val_with_orderings
+        };
+
+        // Group 0: one row (ordering 1). Group 1: orderings -9 (NULL value) and
+        // -6, so -6 sorts last. Group 2: its only row is filtered out.
+        group_acc.update_batch(
+            &val_with_orderings,
+            &[0, 1, 2, 1],
+            Some(&BooleanArray::from(vec![true, true, false, true])),
+            3,
+        )?;
+
+        let state = group_acc.state(EmitTo::All)?;
+
+        // State columns: value, ordering key, and whether the group has been set.
+        let expected_state: Vec<Arc<dyn Array>> = vec![
+            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
+            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
+            Arc::new(BooleanArray::from(vec![true, true, false])),
+        ];
+        assert_eq!(state, expected_state);
+
+        // Merge the emitted state back; the filter only lets group 0's row through.
+        group_acc.merge_batch(
+            &state,
+            &[0, 1, 2],
+            Some(&BooleanArray::from(vec![true, false, false])),
+            3,
+        )?;
+
+        // Group 1 receives 66 and group 2 receives 6; group 3 is new and never
+        // receives a row, so it evaluates to NULL.
+        val_with_orderings.clear();
+        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
+        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
+
+        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;
+
+        let binding = group_acc.evaluate(EmitTo::All)?;
+        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();
+
+        let expect: PrimitiveArray<Int64Type> =
+            Int64Array::from(vec![Some(1), Some(66), Some(6), None]);
+
+        assert_eq!(eval_result, &expect);
+
+        Ok(())
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt
index 4c4999a364..9e67018ecd 100644
--- a/datafusion/sqllogictest/test_files/group_by.slt
+++ b/datafusion/sqllogictest/test_files/group_by.slt
@@ -2232,7 +2232,7 @@ physical_plan
 03)----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST]
 
 query III
-SELECT a, b, LAST_VALUE(c) as last_c
+SELECT a, b, LAST_VALUE(c order by c) as last_c
   FROM annotated_data_infinite2
   GROUP BY a, b
 ----
@@ -2706,6 +2706,29 @@ select k, first_value(val order by o) respect NULLS from first_null group by k;
 1 1
 
 
+statement ok
+CREATE TABLE last_null (
+          k INT,
+          val INT,
+          o int
+        ) as VALUES
+          (0, NULL, 9),
+          (0, 1, 1),
+          (1, 1, 1);
+
+query II rowsort
+select k, last_value(val order by o) IGNORE NULLS from last_null group by k;
+----
+0 1
+1 1
+
+query II rowsort
+select k, last_value(val order by o) respect NULLS from last_null group by k;
+----
+0 NULL
+1 1
+
+
 query TT
 EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts,
   FIRST_VALUE(amount ORDER BY amount ASC) AS fv1,
@@ -3775,7 +3798,7 @@ ORDER BY x;
 2 2
 
 query II
-SELECT y, LAST_VALUE(x)
+SELECT y, LAST_VALUE(x order by x desc)
 FROM FOO
 GROUP BY y
 ORDER BY y;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to