This is an automated email from the ASF dual-hosted git repository.
ytyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 923bfb7fc7 Improve performance of `first_value` by implementing
special `GroupsAccumulator` (#15266)
923bfb7fc7 is described below
commit 923bfb7fc7cf1718522f572b74f3756d02652933
Author: UBarney <[email protected]>
AuthorDate: Wed Mar 26 10:56:36 2025 +0800
Improve performance of `first_value` by implementing special
`GroupsAccumulator` (#15266)
* Improve speed of first_value by implementing special GroupsAccumulator
* rename and other improvements
* `append_n` -> `resize`
* address comment
* use HashMap::entry
* remove hashMap in get_filtered_min_of_each_group
---
datafusion/common/src/utils/mod.rs | 17 +
datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs | 27 +-
.../aggregation_fuzzer/data_generator.rs | 4 +
.../tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs | 66 +-
datafusion/functions-aggregate/src/first_last.rs | 754 ++++++++++++++++++++-
datafusion/sqllogictest/test_files/group_by.slt | 41 ++
6 files changed, 897 insertions(+), 12 deletions(-)
diff --git a/datafusion/common/src/utils/mod.rs
b/datafusion/common/src/utils/mod.rs
index ff9cdedab8..409f248621 100644
--- a/datafusion/common/src/utils/mod.rs
+++ b/datafusion/common/src/utils/mod.rs
@@ -82,6 +82,23 @@ pub fn project_schema(
Ok(schema)
}
+/// Extracts a row at the specified index from a set of columns and stores it
in the provided buffer.
+pub fn extract_row_at_idx_to_buf(
+ columns: &[ArrayRef],
+ idx: usize,
+ buf: &mut Vec<ScalarValue>,
+) -> Result<()> {
+ buf.clear();
+
+ let iter = columns
+ .iter()
+ .map(|arr| ScalarValue::try_from_array(arr, idx));
+ for v in iter.into_iter() {
+ buf.push(v?);
+ }
+
+ Ok(())
+}
/// Given column vectors, returns row at `idx`.
pub fn get_row_at_idx(columns: &[ArrayRef], idx: usize) ->
Result<Vec<ScalarValue>> {
columns
diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 46221acfcc..1b98a19581 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -15,7 +15,6 @@
// specific language governing permissions and limitations
// under the License.
-use std::str;
use std::sync::Arc;
use crate::fuzz_cases::aggregation_fuzzer::{
@@ -88,6 +87,32 @@ async fn test_min() {
.await;
}
+#[tokio::test(flavor = "multi_thread")]
+async fn test_first_val() {
+ let mut data_gen_config: DatasetGeneratorConfig = baseline_config();
+
+ for i in 0..data_gen_config.columns.len() {
+ if data_gen_config.columns[i].get_max_num_distinct().is_none() {
+ data_gen_config.columns[i] = data_gen_config.columns[i]
+ .clone()
+ // Minimize the chance of identical values in the order by
columns to make the test more stable
+ .with_max_num_distinct(usize::MAX);
+ }
+ }
+
+ let query_builder = QueryBuilder::new()
+ .with_table_name("fuzz_table")
+ .with_aggregate_function("first_value")
+ .with_aggregate_arguments(data_gen_config.all_columns())
+ .set_group_by_columns(data_gen_config.all_columns());
+
+ AggregationFuzzerBuilder::from(data_gen_config)
+ .add_query_builder(query_builder)
+ .build()
+ .run()
+ .await;
+}
+
#[tokio::test(flavor = "multi_thread")]
async fn test_max() {
let data_gen_config = baseline_config();
diff --git
a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs
b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs
index 54c5744c86..d61835a080 100644
--- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs
@@ -228,6 +228,10 @@ impl ColumnDescr {
}
}
+ pub fn get_max_num_distinct(&self) -> Option<usize> {
+ self.max_num_distinct
+ }
+
/// set the maximum number of distinct values in this column
///
/// If `None`, the number of distinct values is randomly selected between 1
diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
index c608adda5d..bb24fb554d 100644
--- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
@@ -15,13 +15,14 @@
// specific language governing permissions and limitations
// under the License.
-use std::collections::HashSet;
use std::sync::Arc;
+use std::{collections::HashSet, str::FromStr};
use arrow::array::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use datafusion_common::{DataFusionError, Result};
use datafusion_common_runtime::JoinSet;
+use rand::seq::SliceRandom;
use rand::{thread_rng, Rng};
use crate::fuzz_cases::aggregation_fuzzer::{
@@ -452,7 +453,11 @@ impl QueryBuilder {
pub fn generate_query(&self) -> String {
let group_by = self.random_group_by();
let mut query = String::from("SELECT ");
- query.push_str(&self.random_aggregate_functions().join(", "));
+ query.push_str(&group_by.join(", "));
+ if !group_by.is_empty() {
+ query.push_str(", ");
+ }
+ query.push_str(&self.random_aggregate_functions(&group_by).join(", "));
query.push_str(" FROM ");
query.push_str(&self.table_name);
if !group_by.is_empty() {
@@ -474,7 +479,7 @@ impl QueryBuilder {
/// * `function_names` are randomly selected from
[`Self::aggregate_functions`]
/// * `<DISTINCT> argument` is randomly selected from [`Self::arguments`]
/// * `alias` is a unique alias `colN` for the column (to avoid duplicate
column names)
- fn random_aggregate_functions(&self) -> Vec<String> {
+ fn random_aggregate_functions(&self, group_by_cols: &[String]) ->
Vec<String> {
const MAX_NUM_FUNCTIONS: usize = 5;
let mut rng = thread_rng();
let num_aggregate_functions = rng.gen_range(1..MAX_NUM_FUNCTIONS);
@@ -482,6 +487,14 @@ impl QueryBuilder {
let mut alias_gen = 1;
let mut aggregate_functions = vec![];
+
+ let mut order_by_black_list: HashSet<String> =
+ group_by_cols.iter().cloned().collect();
+ // remove one random col
+ if let Some(first) = order_by_black_list.iter().next().cloned() {
+ order_by_black_list.remove(&first);
+ }
+
while aggregate_functions.len() < num_aggregate_functions {
let idx = rng.gen_range(0..self.aggregate_functions.len());
let (function_name, is_distinct) = &self.aggregate_functions[idx];
@@ -489,7 +502,19 @@ impl QueryBuilder {
let alias = format!("col{}", alias_gen);
let distinct = if *is_distinct { "DISTINCT " } else { "" };
alias_gen += 1;
- let function = format!("{function_name}({distinct}{argument}) as
{alias}");
+
+ let (order_by, null_opt) = if function_name.eq("first_value") {
+ (
+ self.order_by(&order_by_black_list), /* Among the order by
columns, at most one group by column can be included to avoid all order by
column values being identical */
+ self.null_opt(),
+ )
+ } else {
+ ("".to_string(), "".to_string())
+ };
+
+ let function = format!(
+ "{function_name}({distinct}{argument}{order_by}) {null_opt} as
{alias}"
+ );
aggregate_functions.push(function);
}
aggregate_functions
@@ -502,6 +527,39 @@ impl QueryBuilder {
self.arguments[idx].clone()
}
+ fn order_by(&self, black_list: &HashSet<String>) -> String {
+ let mut available_columns: Vec<String> = self
+ .arguments
+ .iter()
+ .filter(|col| !black_list.contains(*col))
+ .cloned()
+ .collect();
+
+ available_columns.shuffle(&mut thread_rng());
+
+ let num_of_order_by_col = 12;
+ let column_count = std::cmp::min(num_of_order_by_col,
available_columns.len());
+
+ let selected_columns = &available_columns[0..column_count];
+
+ let mut rng = thread_rng();
+ let mut result = String::from_str(" order by ").unwrap();
+ for col in selected_columns {
+ let order = if rng.gen_bool(0.5) { "ASC" } else { "DESC" };
+ result.push_str(&format!("{} {},", col, order));
+ }
+
+ result.strip_suffix(",").unwrap().to_string()
+ }
+
+ fn null_opt(&self) -> String {
+ if thread_rng().gen_bool(0.5) {
+ "RESPECT NULLS".to_string()
+ } else {
+ "IGNORE NULLS".to_string()
+ }
+ }
+
/// Pick a random number of fields to group by (non-repeating)
///
/// Limited to 3 group by columns to ensure coverage for large groups. With
diff --git a/datafusion/functions-aggregate/src/first_last.rs
b/datafusion/functions-aggregate/src/first_last.rs
index 6df8ede4fc..28e6a8723d 100644
--- a/datafusion/functions-aggregate/src/first_last.rs
+++ b/datafusion/functions-aggregate/src/first_last.rs
@@ -22,18 +22,30 @@ use std::fmt::Debug;
use std::mem::size_of_val;
use std::sync::Arc;
-use arrow::array::{ArrayRef, AsArray, BooleanArray};
-use arrow::compute::{self, LexicographicalComparator, SortColumn};
-use arrow::datatypes::{DataType, Field};
-use datafusion_common::utils::{compare_rows, get_row_at_idx};
+use arrow::array::{
+ Array, ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray,
BooleanBufferBuilder,
+ PrimitiveArray,
+};
+use arrow::buffer::{BooleanBuffer, NullBuffer};
+use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions};
+use arrow::datatypes::{
+ DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
Float16Type,
+ Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
+ Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
Time64NanosecondType,
+ TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
+ TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type,
UInt64Type,
+ UInt8Type,
+};
+use datafusion_common::cast::as_boolean_array;
+use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf,
get_row_at_idx};
use datafusion_common::{
arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
use datafusion_expr::{
- Accumulator, AggregateUDFImpl, Documentation, Expr, ExprFunctionExt,
Signature,
- SortExpr, Volatility,
+ Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr,
ExprFunctionExt,
+ GroupsAccumulator, Signature, SortExpr, Volatility,
};
use datafusion_functions_aggregate_common::utils::get_sort_options;
use datafusion_macros::user_doc;
@@ -153,6 +165,106 @@ impl AggregateUDFImpl for FirstValue {
Ok(fields)
}
+ fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+ use DataType::*;
+ matches!(
+ args.return_type,
+ Int8 | Int16
+ | Int32
+ | Int64
+ | UInt8
+ | UInt16
+ | UInt32
+ | UInt64
+ | Float16
+ | Float32
+ | Float64
+ | Decimal128(_, _)
+ | Decimal256(_, _)
+ | Date32
+ | Date64
+ | Time32(_)
+ | Time64(_)
+ | Timestamp(_, _)
+ )
+ }
+
+ fn create_groups_accumulator(
+ &self,
+ args: AccumulatorArgs,
+ ) -> Result<Box<dyn GroupsAccumulator>> {
+ fn create_accumulator<T>(
+ args: AccumulatorArgs,
+ ) -> Result<Box<dyn GroupsAccumulator>>
+ where
+ T: ArrowPrimitiveType + Send,
+ {
+ let ordering_dtypes = args
+ .ordering_req
+ .iter()
+ .map(|e| e.expr.data_type(args.schema))
+ .collect::<Result<Vec<_>>>()?;
+
+ Ok(Box::new(FirstPrimitiveGroupsAccumulator::<T>::try_new(
+ args.ordering_req.clone(),
+ args.ignore_nulls,
+ args.return_type,
+ &ordering_dtypes,
+ )?))
+ }
+
+ match args.return_type {
+ DataType::Int8 => create_accumulator::<Int8Type>(args),
+ DataType::Int16 => create_accumulator::<Int16Type>(args),
+ DataType::Int32 => create_accumulator::<Int32Type>(args),
+ DataType::Int64 => create_accumulator::<Int64Type>(args),
+ DataType::UInt8 => create_accumulator::<UInt8Type>(args),
+ DataType::UInt16 => create_accumulator::<UInt16Type>(args),
+ DataType::UInt32 => create_accumulator::<UInt32Type>(args),
+ DataType::UInt64 => create_accumulator::<UInt64Type>(args),
+ DataType::Float16 => create_accumulator::<Float16Type>(args),
+ DataType::Float32 => create_accumulator::<Float32Type>(args),
+ DataType::Float64 => create_accumulator::<Float64Type>(args),
+
+ DataType::Decimal128(_, _) =>
create_accumulator::<Decimal128Type>(args),
+ DataType::Decimal256(_, _) =>
create_accumulator::<Decimal256Type>(args),
+
+ DataType::Timestamp(TimeUnit::Second, _) => {
+ create_accumulator::<TimestampSecondType>(args)
+ }
+ DataType::Timestamp(TimeUnit::Millisecond, _) => {
+ create_accumulator::<TimestampMillisecondType>(args)
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, _) => {
+ create_accumulator::<TimestampMicrosecondType>(args)
+ }
+ DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+ create_accumulator::<TimestampNanosecondType>(args)
+ }
+
+ DataType::Date32 => create_accumulator::<Date32Type>(args),
+ DataType::Date64 => create_accumulator::<Date64Type>(args),
+ DataType::Time32(TimeUnit::Second) => {
+ create_accumulator::<Time32SecondType>(args)
+ }
+ DataType::Time32(TimeUnit::Millisecond) => {
+ create_accumulator::<Time32MillisecondType>(args)
+ }
+
+ DataType::Time64(TimeUnit::Microsecond) => {
+ create_accumulator::<Time64MicrosecondType>(args)
+ }
+ DataType::Time64(TimeUnit::Nanosecond) => {
+ create_accumulator::<Time64NanosecondType>(args)
+ }
+
+ _ => internal_err!(
+ "GroupsAccumulator not supported for first({})",
+ args.return_type
+ ),
+ }
+ }
+
fn aliases(&self) -> &[String] {
&[]
}
@@ -179,6 +291,460 @@ impl AggregateUDFImpl for FirstValue {
}
}
+struct FirstPrimitiveGroupsAccumulator<T>
+where
+ T: ArrowPrimitiveType + Send,
+{
+ // ================ state ===========
+ vals: Vec<T::Native>,
+ // Stores ordering values, of the aggregator requirement corresponding to
first value
+ // of the aggregator.
+ // The `orderings` are stored row-wise, meaning that `orderings[group_idx]`
+ // represents the ordering values corresponding to the `group_idx`-th
group.
+ orderings: Vec<Vec<ScalarValue>>,
+ // At the beginning, `is_sets[group_idx]` is false, which means `first` is
not seen yet.
+ // Once we see the first value, we set the `is_sets[group_idx]` flag
+ is_sets: BooleanBufferBuilder,
+ // null_builder[group_idx] == false => vals[group_idx] is null
+ null_builder: BooleanBufferBuilder,
+ // size of `self.orderings`
+ // Calculating the memory usage of `self.orderings` using
`ScalarValue::size_of_vec` is quite costly.
+ // Therefore, we cache it and compute `size_of` only after each update
+ // to avoid calling `ScalarValue::size_of_vec` by Self.size.
+ size_of_orderings: usize,
+
+ // buffer for `get_filtered_min_of_each_group`
+ // filter_min_of_each_group_buf.0[group_idx] -> idx_in_val
+ // only valid if filter_min_of_each_group_buf.1[group_idx] == true
+ min_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),
+
+ // =========== option ============
+
+ // Stores the applicable ordering requirement.
+ ordering_req: LexOrdering,
+ // derived from `ordering_req`.
+ sort_options: Vec<SortOptions>,
+ // Stores whether incoming data already satisfies the ordering requirement.
+ input_requirement_satisfied: bool,
+ // Ignore null values.
+ ignore_nulls: bool,
+ /// The output type
+ data_type: DataType,
+ default_orderings: Vec<ScalarValue>,
+}
+
+impl<T> FirstPrimitiveGroupsAccumulator<T>
+where
+ T: ArrowPrimitiveType + Send,
+{
+ fn try_new(
+ ordering_req: LexOrdering,
+ ignore_nulls: bool,
+ data_type: &DataType,
+ ordering_dtypes: &[DataType],
+ ) -> Result<Self> {
+ let requirement_satisfied = ordering_req.is_empty();
+
+ let default_orderings = ordering_dtypes
+ .iter()
+ .map(ScalarValue::try_from)
+ .collect::<Result<Vec<_>>>()?;
+
+ let sort_options = get_sort_options(ordering_req.as_ref());
+
+ Ok(Self {
+ null_builder: BooleanBufferBuilder::new(0),
+ ordering_req,
+ sort_options,
+ input_requirement_satisfied: requirement_satisfied,
+ ignore_nulls,
+ default_orderings,
+ data_type: data_type.clone(),
+ vals: Vec::new(),
+ orderings: Vec::new(),
+ is_sets: BooleanBufferBuilder::new(0),
+ size_of_orderings: 0,
+ min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)),
+ })
+ }
+
+ fn need_update(&self, group_idx: usize) -> bool {
+ if !self.is_sets.get_bit(group_idx) {
+ return true;
+ }
+
+ if self.ignore_nulls && !self.null_builder.get_bit(group_idx) {
+ return true;
+ }
+
+ !self.input_requirement_satisfied
+ }
+
+ fn should_update_state(
+ &self,
+ group_idx: usize,
+ new_ordering_values: &[ScalarValue],
+ ) -> Result<bool> {
+ if !self.is_sets.get_bit(group_idx) {
+ return Ok(true);
+ }
+
+ assert!(new_ordering_values.len() == self.ordering_req.len());
+ let current_ordering = &self.orderings[group_idx];
+ compare_rows(current_ordering, new_ordering_values, &self.sort_options)
+ .map(|x| x.is_gt())
+ }
+
+ fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
+ let result = emit_to.take_needed(&mut self.orderings);
+
+ match emit_to {
+ EmitTo::All => self.size_of_orderings = 0,
+ EmitTo::First(_) => {
+ self.size_of_orderings -=
+ result.iter().map(ScalarValue::size_of_vec).sum::<usize>()
+ }
+ }
+
+ result
+ }
+
+ fn take_need(
+ bool_buf_builder: &mut BooleanBufferBuilder,
+ emit_to: EmitTo,
+ ) -> BooleanBuffer {
+ let bool_buf = bool_buf_builder.finish();
+ match emit_to {
+ EmitTo::All => bool_buf,
+ EmitTo::First(n) => {
+ // split off the first N values in seen_values
+ //
+ // TODO make this more efficient rather than two
+ // copies and bitwise manipulation
+ let first_n: BooleanBuffer = bool_buf.iter().take(n).collect();
+ // reset the existing buffer
+ for b in bool_buf.iter().skip(n) {
+ bool_buf_builder.append(b);
+ }
+ first_n
+ }
+ }
+ }
+
+ fn resize_states(&mut self, new_size: usize) {
+ self.vals.resize(new_size, T::default_value());
+
+ self.null_builder.resize(new_size);
+
+ if self.orderings.len() < new_size {
+ let current_len = self.orderings.len();
+
+ self.orderings
+ .resize(new_size, self.default_orderings.clone());
+
+ self.size_of_orderings += (new_size - current_len)
+ * ScalarValue::size_of_vec(
+ // Note: In some cases (such as in the unit test below)
+ // ScalarValue::size_of_vec(&self.default_orderings) !=
ScalarValue::size_of_vec(&self.default_orderings.clone())
+ // This may be caused by the different vec.capacity()
values?
+ self.orderings.last().unwrap(),
+ );
+ }
+
+ self.is_sets.resize(new_size);
+
+ self.min_of_each_group_buf.0.resize(new_size, 0);
+ self.min_of_each_group_buf.1.resize(new_size);
+ }
+
+ fn update_state(
+ &mut self,
+ group_idx: usize,
+ orderings: &[ScalarValue],
+ new_val: T::Native,
+ is_null: bool,
+ ) {
+ self.vals[group_idx] = new_val;
+ self.is_sets.set_bit(group_idx, true);
+
+ self.null_builder.set_bit(group_idx, !is_null);
+
+ assert!(orderings.len() == self.ordering_req.len());
+ let old_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
+ self.orderings[group_idx].clear();
+ self.orderings[group_idx].extend_from_slice(orderings);
+ let new_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
+ self.size_of_orderings = self.size_of_orderings - old_size + new_size;
+ }
+
+ fn take_state(
+ &mut self,
+ emit_to: EmitTo,
+ ) -> (ArrayRef, Vec<Vec<ScalarValue>>, BooleanBuffer) {
+ emit_to.take_needed(&mut self.min_of_each_group_buf.0);
+ self.min_of_each_group_buf
+ .1
+ .truncate(self.min_of_each_group_buf.0.len());
+
+ (
+ self.take_vals_and_null_buf(emit_to),
+ self.take_orderings(emit_to),
+ Self::take_need(&mut self.is_sets, emit_to),
+ )
+ }
+
+ // should be used in test only
+ #[cfg(test)]
+ fn compute_size_of_orderings(&self) -> usize {
+ self.orderings
+ .iter()
+ .map(ScalarValue::size_of_vec)
+ .sum::<usize>()
+ }
+
+ /// Returns a vector of tuples `(group_idx, idx_in_val)` representing the
index of the
+ /// minimum value in `orderings` for each group, using lexicographical
comparison.
+ /// Values are filtered using `opt_filter` and `is_set_arr` if provided.
+ fn get_filtered_min_of_each_group(
+ &mut self,
+ orderings: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ vals: &PrimitiveArray<T>,
+ is_set_arr: Option<&BooleanArray>,
+ ) -> Result<Vec<(usize, usize)>> {
+ // Set all values in min_of_each_group_buf.1 to false.
+ self.min_of_each_group_buf.1.truncate(0);
+ self.min_of_each_group_buf
+ .1
+ .append_n(self.vals.len(), false);
+
+ // No need to call `clear` since
`self.min_of_each_group_buf.0[group_idx]`
+ // is only valid when `self.min_of_each_group_buf.1[group_idx] ==
true`.
+
+ let comparator = {
+ assert_eq!(orderings.len(), self.ordering_req.len());
+ let sort_columns = orderings
+ .iter()
+ .zip(self.ordering_req.iter())
+ .map(|(array, req)| SortColumn {
+ values: Arc::clone(array),
+ options: Some(req.options),
+ })
+ .collect::<Vec<_>>();
+
+ LexicographicalComparator::try_new(&sort_columns)?
+ };
+
+ for (idx_in_val, group_idx) in group_indices.iter().enumerate() {
+ let group_idx = *group_idx;
+
+ let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val));
+
+ let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val));
+
+ if !passed_filter || !is_set {
+ continue;
+ }
+
+ if !self.need_update(group_idx) {
+ continue;
+ }
+
+ if self.ignore_nulls && vals.is_null(idx_in_val) {
+ continue;
+ }
+
+ let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx);
+ if is_valid
+ && comparator
+ .compare(self.min_of_each_group_buf.0[group_idx],
idx_in_val)
+ .is_gt()
+ {
+ self.min_of_each_group_buf.0[group_idx] = idx_in_val;
+ } else if !is_valid {
+ self.min_of_each_group_buf.1.set_bit(group_idx, true);
+ self.min_of_each_group_buf.0[group_idx] = idx_in_val;
+ }
+ }
+
+ Ok(self
+ .min_of_each_group_buf
+ .0
+ .iter()
+ .enumerate()
+ .filter(|(group_idx, _)|
self.min_of_each_group_buf.1.get_bit(*group_idx))
+ .map(|(group_idx, idx_in_val)| (group_idx, *idx_in_val))
+ .collect::<Vec<_>>())
+ }
+
+ fn take_vals_and_null_buf(&mut self, emit_to: EmitTo) -> ArrayRef {
+ let r = emit_to.take_needed(&mut self.vals);
+
+ let null_buf = NullBuffer::new(Self::take_need(&mut self.null_builder,
emit_to));
+
+ let values = PrimitiveArray::<T>::new(r.into(), Some(null_buf)) // no
copy
+ .with_data_type(self.data_type.clone());
+ Arc::new(values)
+ }
+}
+
+impl<T> GroupsAccumulator for FirstPrimitiveGroupsAccumulator<T>
+where
+ T: ArrowPrimitiveType + Send,
+{
+ fn update_batch(
+ &mut self,
+ // e.g. first_value(a order by b): values_and_order_cols will be [a, b]
+ values_and_order_cols: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ total_num_groups: usize,
+ ) -> Result<()> {
+ self.resize_states(total_num_groups);
+
+ let vals = values_and_order_cols[0].as_primitive::<T>();
+
+ let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
+
+ // The overhead of calling `extract_row_at_idx_to_buf` is somewhat
high, so we need to minimize its calls as much as possible.
+ for (group_idx, idx) in self
+ .get_filtered_min_of_each_group(
+ &values_and_order_cols[1..],
+ group_indices,
+ opt_filter,
+ vals,
+ None,
+ )?
+ .into_iter()
+ {
+ extract_row_at_idx_to_buf(
+ &values_and_order_cols[1..],
+ idx,
+ &mut ordering_buf,
+ )?;
+
+ if self.should_update_state(group_idx, &ordering_buf)? {
+ self.update_state(
+ group_idx,
+ &ordering_buf,
+ vals.value(idx),
+ vals.is_null(idx),
+ );
+ }
+ }
+
+ Ok(())
+ }
+
+ fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+ Ok(self.take_state(emit_to).0)
+ }
+
+ fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+ let (val_arr, orderings, is_sets) = self.take_state(emit_to);
+ let mut result = Vec::with_capacity(self.orderings.len() + 2);
+
+ result.push(val_arr);
+
+ let ordering_cols = {
+ let mut ordering_cols =
Vec::with_capacity(self.ordering_req.len());
+ for _ in 0..self.ordering_req.len() {
+ ordering_cols.push(Vec::with_capacity(self.orderings.len()));
+ }
+ for row in orderings.into_iter() {
+ assert_eq!(row.len(), self.ordering_req.len());
+ for (col_idx, ordering) in row.into_iter().enumerate() {
+ ordering_cols[col_idx].push(ordering);
+ }
+ }
+
+ ordering_cols
+ };
+ for ordering_col in ordering_cols {
+ result.push(ScalarValue::iter_to_array(ordering_col)?);
+ }
+
+ result.push(Arc::new(BooleanArray::new(is_sets, None)));
+
+ Ok(result)
+ }
+
+ fn merge_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ total_num_groups: usize,
+ ) -> Result<()> {
+ self.resize_states(total_num_groups);
+
+ let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
+
+ let (is_set_arr, val_and_order_cols) = match values.split_last() {
+ Some(result) => result,
+            None => return internal_err!("Empty row in FIRST_VALUE"),
+ };
+
+ let is_set_arr = as_boolean_array(is_set_arr)?;
+
+ let vals = values[0].as_primitive::<T>();
+ // The overhead of calling `extract_row_at_idx_to_buf` is somewhat
high, so we need to minimize its calls as much as possible.
+ let groups = self.get_filtered_min_of_each_group(
+ &val_and_order_cols[1..],
+ group_indices,
+ opt_filter,
+ vals,
+ Some(is_set_arr),
+ )?;
+
+ for (group_idx, idx) in groups.into_iter() {
+ extract_row_at_idx_to_buf(&val_and_order_cols[1..], idx, &mut
ordering_buf)?;
+
+ if self.should_update_state(group_idx, &ordering_buf)? {
+ self.update_state(
+ group_idx,
+ &ordering_buf,
+ vals.value(idx),
+ vals.is_null(idx),
+ );
+ }
+ }
+
+ Ok(())
+ }
+
+ fn size(&self) -> usize {
+ self.vals.capacity() * size_of::<T::Native>()
+ + self.null_builder.capacity() / 8 // capacity is in bits, so
convert to bytes
+ + self.is_sets.capacity() / 8
+ + self.size_of_orderings
+ + self.min_of_each_group_buf.0.capacity() * size_of::<usize>()
+ + self.min_of_each_group_buf.1.capacity() / 8
+ }
+
+ fn supports_convert_to_state(&self) -> bool {
+ true
+ }
+
+ fn convert_to_state(
+ &self,
+ values: &[ArrayRef],
+ opt_filter: Option<&BooleanArray>,
+ ) -> Result<Vec<ArrayRef>> {
+ let mut result = values.to_vec();
+ match opt_filter {
+ Some(f) => {
+ result.push(Arc::new(f.clone()));
+ Ok(result)
+ }
+ None => {
+ result.push(Arc::new(BooleanArray::from(vec![true;
values[0].len()])));
+ Ok(result)
+ }
+ }
+ }
+}
#[derive(Debug)]
pub struct FirstValueAccumulator {
first: ScalarValue,
@@ -684,7 +1250,8 @@ fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs:
&LexOrdering) -> Vec<Sort
#[cfg(test)]
mod tests {
- use arrow::array::Int64Array;
+ use arrow::{array::Int64Array, compute::SortOptions, datatypes::Schema};
+ use datafusion_physical_expr::{expressions::col, PhysicalSortExpr};
use super::*;
@@ -823,4 +1390,177 @@ mod tests {
Ok(())
}
+
+ #[test]
+    fn test_first_group_acc() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int64, true),
+ Field::new("b", DataType::Int64, true),
+ Field::new("c", DataType::Int64, true),
+ Field::new("d", DataType::Int32, true),
+ Field::new("e", DataType::Boolean, true),
+ ]));
+
+ let sort_key = LexOrdering::new(vec![PhysicalSortExpr {
+ expr: col("c", &schema).unwrap(),
+ options: SortOptions::default(),
+ }]);
+
+ let mut group_acc =
FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
+ sort_key,
+ true,
+ &DataType::Int64,
+ &[DataType::Int64],
+ )?;
+
+ let mut val_with_orderings = {
+ let mut val_with_orderings = Vec::<ArrayRef>::new();
+
+ let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3),
Some(-6)]));
+ let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
+
+ val_with_orderings.push(vals);
+ val_with_orderings.push(orderings);
+
+ val_with_orderings
+ };
+
+ group_acc.update_batch(
+ &val_with_orderings,
+ &[0, 1, 2, 1],
+ Some(&BooleanArray::from(vec![true, true, false, true])),
+ 3,
+ )?;
+ assert_eq!(
+ group_acc.size_of_orderings,
+ group_acc.compute_size_of_orderings()
+ );
+
+ let state = group_acc.state(EmitTo::All)?;
+
+ let expected_state: Vec<Arc<dyn Array>> = vec![
+ Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
+ Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
+ Arc::new(BooleanArray::from(vec![true, true, false])),
+ ];
+ assert_eq!(state, expected_state);
+
+ assert_eq!(
+ group_acc.size_of_orderings,
+ group_acc.compute_size_of_orderings()
+ );
+
+ group_acc.merge_batch(
+ &state,
+ &[0, 1, 2],
+ Some(&BooleanArray::from(vec![true, false, false])),
+ 3,
+ )?;
+
+ assert_eq!(
+ group_acc.size_of_orderings,
+ group_acc.compute_size_of_orderings()
+ );
+
+ val_with_orderings.clear();
+ val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));
+ val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));
+
+ group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;
+
+ let binding = group_acc.evaluate(EmitTo::All)?;
+ let eval_result =
binding.as_any().downcast_ref::<Int64Array>().unwrap();
+
+ let expect: PrimitiveArray<Int64Type> =
+ Int64Array::from(vec![Some(1), Some(6), Some(6), None]);
+
+ assert_eq!(eval_result, &expect);
+
+ assert_eq!(
+ group_acc.size_of_orderings,
+ group_acc.compute_size_of_orderings()
+ );
+
+ Ok(())
+ }
+
+ #[test]
+    fn test_first_group_acc_size_of_ordering() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int64, true),
+ Field::new("b", DataType::Int64, true),
+ Field::new("c", DataType::Int64, true),
+ Field::new("d", DataType::Int32, true),
+ Field::new("e", DataType::Boolean, true),
+ ]));
+
+ let sort_key = LexOrdering::new(vec![PhysicalSortExpr {
+ expr: col("c", &schema).unwrap(),
+ options: SortOptions::default(),
+ }]);
+
+ let mut group_acc =
FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
+ sort_key,
+ true,
+ &DataType::Int64,
+ &[DataType::Int64],
+ )?;
+
+ let val_with_orderings = {
+ let mut val_with_orderings = Vec::<ArrayRef>::new();
+
+ let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3),
Some(-6)]));
+ let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
+
+ val_with_orderings.push(vals);
+ val_with_orderings.push(orderings);
+
+ val_with_orderings
+ };
+
+ for _ in 0..10 {
+ group_acc.update_batch(
+ &val_with_orderings,
+ &[0, 1, 2, 1],
+ Some(&BooleanArray::from(vec![true, true, false, true])),
+ 100,
+ )?;
+ assert_eq!(
+ group_acc.size_of_orderings,
+ group_acc.compute_size_of_orderings()
+ );
+
+ group_acc.state(EmitTo::First(2))?;
+ assert_eq!(
+ group_acc.size_of_orderings,
+ group_acc.compute_size_of_orderings()
+ );
+
+ let s = group_acc.state(EmitTo::All)?;
+ assert_eq!(
+ group_acc.size_of_orderings,
+ group_acc.compute_size_of_orderings()
+ );
+
+ group_acc.merge_batch(&s, &Vec::from_iter(0..s[0].len()), None,
100)?;
+ assert_eq!(
+ group_acc.size_of_orderings,
+ group_acc.compute_size_of_orderings()
+ );
+
+ group_acc.evaluate(EmitTo::First(2))?;
+ assert_eq!(
+ group_acc.size_of_orderings,
+ group_acc.compute_size_of_orderings()
+ );
+
+ group_acc.evaluate(EmitTo::All)?;
+ assert_eq!(
+ group_acc.size_of_orderings,
+ group_acc.compute_size_of_orderings()
+ );
+ }
+
+ Ok(())
+ }
}
diff --git a/datafusion/sqllogictest/test_files/group_by.slt
b/datafusion/sqllogictest/test_files/group_by.slt
index 5bf539e0b0..d9ef12496e 100644
--- a/datafusion/sqllogictest/test_files/group_by.slt
+++ b/datafusion/sqllogictest/test_files/group_by.slt
@@ -2665,6 +2665,47 @@ TUR [100.0, 75.0] 175
# test_reverse_aggregate_expr
# Some of the Aggregators can be reversed, by this way we can still run
aggregators without re-ordering
# that have contradictory requirements at first glance.
+
+statement ok
+CREATE TABLE null_group (
+ a INT, b INT, c INT, d INT
+) as VALUES
+ (6, 6, null, null),
+ (6, 6, 1, null),
+ (6, 6, null, 1)
+
+query III rowsort
+select c, d, first_value(a order by b) from null_group group by c, d;
+----
+1 NULL 6
+NULL 1 6
+NULL NULL 6
+
+
+
+statement ok
+CREATE TABLE first_null (
+ k INT,
+ val INT,
+ o int
+ ) as VALUES
+ (0, NULL, -9),
+ (0, 1, 1),
+ (1, 1, 1);
+
+query II rowsort
+select k, first_value(val order by o) IGNORE NULLS from first_null group by k;
+----
+0 1
+1 1
+
+query II rowsort
+select k, first_value(val order by o) respect NULLS from first_null group by k;
+----
+0 NULL
+1 1
+
+
query TT
EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts,
FIRST_VALUE(amount ORDER BY amount ASC) AS fv1,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]