This is an automated email from the ASF dual-hosted git repository. blaginin pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push: new 5e307b3372 Handle dicts for distinct count (#15871) 5e307b3372 is described below commit 5e307b337210cfb19e0b0d210a51811b22b46728 Author: Dmitrii Blaginin <dmit...@blaginin.me> AuthorDate: Thu Jun 5 17:02:09 2025 +0100 Handle dicts for distinct count (#15871) * Handle dicts for distinct count * Fix sqllogictests * Add bench * Fix no fix the bench * Do not panic if error type is bad * Add full bench query * Set the bench * Add dict of dict test * Fix tests * Rename method * Increase the grouping test * Increase the grouping test a bit more :) * Fix flakiness --------- Co-authored-by: Dmitrii Blaginin <blaginin@bmac.local> --- .../src/aggregate/count_distinct.rs | 2 + .../src/aggregate/count_distinct/dict.rs | 70 ++++++ datafusion/functions-aggregate/benches/count.rs | 46 +++- datafusion/functions-aggregate/src/count.rs | 275 ++++++++++++--------- datafusion/sqllogictest/test_files/aggregate.slt | 10 +- 5 files changed, 285 insertions(+), 118 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs index 7d772f7c64..25b4038229 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs @@ -16,9 +16,11 @@ // under the License. mod bytes; +mod dict; mod native; pub use bytes::BytesDistinctCountAccumulator; pub use bytes::BytesViewDistinctCountAccumulator; +pub use dict::DictionaryCountAccumulator; pub use native::FloatDistinctCountAccumulator; pub use native::PrimitiveDistinctCountAccumulator; diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/dict.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/dict.rs new file mode 100644 index 0000000000..089d8d5acd --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/dict.rs @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, BooleanArray}; +use arrow::downcast_dictionary_array; +use datafusion_common::{arrow_datafusion_err, ScalarValue}; +use datafusion_common::{internal_err, DataFusionError}; +use datafusion_expr_common::accumulator::Accumulator; + +#[derive(Debug)] +pub struct DictionaryCountAccumulator { + inner: Box<dyn Accumulator>, +} + +impl DictionaryCountAccumulator { + pub fn new(inner: Box<dyn Accumulator>) -> Self { + Self { inner } + } +} + +impl Accumulator for DictionaryCountAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { + let values: Vec<_> = values + .iter() + .map(|dict| { + downcast_dictionary_array! { + dict => { + let buff: BooleanArray = dict.occupancy().into(); + arrow::compute::filter( + dict.values(), + &buff + ).map_err(|e| arrow_datafusion_err!(e)) + }, + _ => internal_err!("DictionaryCountAccumulator only supports dictionary arrays") + } + }) + .collect::<Result<Vec<_>, _>>()?; + self.inner.update_batch(values.as_slice()) + } + + fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> { + self.inner.evaluate() + } + + fn size(&self) -> usize { + self.inner.size() + } + + fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> { + self.inner.state() + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> { + self.inner.merge_batch(states) + } +} diff --git a/datafusion/functions-aggregate/benches/count.rs b/datafusion/functions-aggregate/benches/count.rs index d5abf6b8ac..cffa50bdda 100644 --- a/datafusion/functions-aggregate/benches/count.rs +++ b/datafusion/functions-aggregate/benches/count.rs @@ -17,15 +17,20 @@ use arrow::array::{ArrayRef, BooleanArray}; use arrow::datatypes::{DataType, Field, Int32Type, Schema}; -use arrow::util::bench_util::{create_boolean_array, create_primitive_array}; +use arrow::util::bench_util::{ + create_boolean_array, create_dict_from_values, create_primitive_array, + create_string_array_with_len, +}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_expr::{function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator}; +use datafusion_expr::{ + function::AccumulatorArgs, Accumulator, AggregateUDFImpl, GroupsAccumulator, +}; use datafusion_functions_aggregate::count::Count; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr_common::sort_expr::LexOrdering; use std::sync::Arc; -fn prepare_accumulator() -> Box<dyn GroupsAccumulator> { +fn prepare_group_accumulator() -> Box<dyn GroupsAccumulator> { let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Int32, true)])); let accumulator_args = AccumulatorArgs { return_field: Field::new("f", DataType::Int64, true).into(), @@ -44,13 +49,34 @@ fn prepare_accumulator() -> Box<dyn GroupsAccumulator> { .unwrap() } +fn prepare_accumulator() -> Box<dyn Accumulator> { + let schema = Arc::new(Schema::new(vec![Field::new( + "f", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + )])); + let accumulator_args = AccumulatorArgs { + return_field: Arc::new(Field::new_list_field(DataType::Int64, true)), + schema: &schema, + ignore_nulls: false, + ordering_req: &LexOrdering::default(), + is_reversed: false, + name: "COUNT(f)", + is_distinct: true, + exprs: &[col("f", &schema).unwrap()], + }; + let count_fn = Count::new(); + + count_fn.accumulator(accumulator_args).unwrap() +} + fn convert_to_state_bench( c: &mut Criterion, name: &str, values: ArrayRef, opt_filter: Option<&BooleanArray>, ) { - let accumulator = prepare_accumulator(); + let accumulator = prepare_group_accumulator(); c.bench_function(name, |b| { b.iter(|| { black_box( @@ -89,6 +115,18 @@ fn count_benchmark(c: &mut Criterion) { values, Some(&filter), ); + + let arr = create_string_array_with_len::<i32>(20, 0.0, 50); + let values = + Arc::new(create_dict_from_values::<Int32Type>(200_000, 0.8, &arr)) as ArrayRef; + + let mut accumulator = prepare_accumulator(); + c.bench_function("count low cardinality dict 20% nulls, no filter", |b| { + b.iter(|| { + #[allow(clippy::unit_arg)] + black_box(accumulator.update_batch(&[values.clone()]).unwrap()) + }) + }); } criterion_group!(benches, count_benchmark); diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index df31465e4a..f375a68d94 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -57,8 +57,8 @@ use datafusion_expr::{ Expr, ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition, }; use datafusion_functions_aggregate_common::aggregate::count_distinct::{ - BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, - PrimitiveDistinctCountAccumulator, + BytesDistinctCountAccumulator, DictionaryCountAccumulator, + FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices; use datafusion_physical_expr_common::binary_map::OutputType; @@ -180,6 +180,107 @@ impl Count { } } } +fn get_count_accumulator(data_type: &DataType) -> Box<dyn Accumulator> { + match data_type { + // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator + DataType::Int8 => Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new( + data_type, + )), + DataType::Int16 => Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new( + data_type, + )), + DataType::Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new( + data_type, + )), + DataType::Int64 => Box::new(PrimitiveDistinctCountAccumulator::<Int64Type>::new( + data_type, + )), + DataType::UInt8 => Box::new(PrimitiveDistinctCountAccumulator::<UInt8Type>::new( + data_type, + )), + DataType::UInt16 => Box::new( + PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type), + ), + DataType::UInt32 => Box::new( + PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type), + ), + DataType::UInt64 => Box::new( + PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type), + ), + DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< + Decimal128Type, + >::new(data_type)), + DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< + Decimal256Type, + >::new(data_type)), + + DataType::Date32 => Box::new( + PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type), + ), + DataType::Date64 => Box::new( + PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type), + ), + DataType::Time32(TimeUnit::Millisecond) => Box::new( + PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(data_type), + ), + DataType::Time32(TimeUnit::Second) => Box::new( + PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type), + ), + DataType::Time64(TimeUnit::Microsecond) => Box::new( + PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(data_type), + ), + DataType::Time64(TimeUnit::Nanosecond) => Box::new( + PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type), + ), + DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(data_type), + ), + DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(data_type), + ), + DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(data_type), + ), + DataType::Timestamp(TimeUnit::Second, _) => Box::new( + PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type), + ), + + DataType::Float16 => { + Box::new(FloatDistinctCountAccumulator::<Float16Type>::new()) + } + DataType::Float32 => { + Box::new(FloatDistinctCountAccumulator::<Float32Type>::new()) + } + DataType::Float64 => { + Box::new(FloatDistinctCountAccumulator::<Float64Type>::new()) + } + + DataType::Utf8 => { + Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8)) + } + DataType::Utf8View => { + Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View)) + } + DataType::LargeUtf8 => { + Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8)) + } + DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new( + OutputType::Binary, + )), + DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new( + OutputType::BinaryView, + )), + DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new( + OutputType::Binary, + )), + + // Use the generic accumulator based on `ScalarValue` for all other types + _ => Box::new(DistinctCountAccumulator { + values: HashSet::default(), + state_data_type: data_type.clone(), + }), + } +} impl AggregateUDFImpl for Count { fn as_any(&self) -> &dyn std::any::Any { @@ -204,10 +305,15 @@ impl AggregateUDFImpl for Count { fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> { if args.is_distinct { + let dtype: DataType = match &args.input_fields[0].data_type() { + DataType::Dictionary(_, values_type) => (**values_type).clone(), + &dtype => dtype.clone(), + }; + Ok(vec![Field::new_list( format_state_name(args.name, "count distinct"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_fields[0].data_type().clone(), true), + Field::new_list_field(dtype, true), false, ) .into()]) @@ -231,114 +337,13 @@ impl AggregateUDFImpl for Count { } let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?; - Ok(match data_type { - // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator - DataType::Int8 => Box::new( - PrimitiveDistinctCountAccumulator::<Int8Type>::new(data_type), - ), - DataType::Int16 => Box::new( - PrimitiveDistinctCountAccumulator::<Int16Type>::new(data_type), - ), - DataType::Int32 => Box::new( - PrimitiveDistinctCountAccumulator::<Int32Type>::new(data_type), - ), - DataType::Int64 => Box::new( - PrimitiveDistinctCountAccumulator::<Int64Type>::new(data_type), - ), - DataType::UInt8 => Box::new( - PrimitiveDistinctCountAccumulator::<UInt8Type>::new(data_type), - ), - DataType::UInt16 => Box::new( - PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type), - ), - DataType::UInt32 => Box::new( - PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type), - ), - DataType::UInt64 => Box::new( - PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type), - ), - DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal128Type, - >::new(data_type)), - DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal256Type, - >::new(data_type)), - - DataType::Date32 => Box::new( - PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type), - ), - DataType::Date64 => Box::new( - PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type), - ), - DataType::Time32(TimeUnit::Millisecond) => Box::new( - PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new( - data_type, - ), - ), - DataType::Time32(TimeUnit::Second) => Box::new( - PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type), - ), - DataType::Time64(TimeUnit::Microsecond) => Box::new( - PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new( - data_type, - ), - ), - DataType::Time64(TimeUnit::Nanosecond) => Box::new( - PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type), - ), - DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new( - PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new( - data_type, - ), - ), - DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new( - PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new( - data_type, - ), - ), - DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new( - PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new( - data_type, - ), - ), - DataType::Timestamp(TimeUnit::Second, _) => Box::new( - PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type), - ), - DataType::Float16 => { - Box::new(FloatDistinctCountAccumulator::<Float16Type>::new()) - } - DataType::Float32 => { - Box::new(FloatDistinctCountAccumulator::<Float32Type>::new()) - } - DataType::Float64 => { - Box::new(FloatDistinctCountAccumulator::<Float64Type>::new()) - } - - DataType::Utf8 => { - Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8)) - } - DataType::Utf8View => { - Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View)) - } - DataType::LargeUtf8 => { - Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8)) + Ok(match data_type { + DataType::Dictionary(_, values_type) => { + let inner = get_count_accumulator(values_type); + Box::new(DictionaryCountAccumulator::new(inner)) } - DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new( - OutputType::Binary, - )), - DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new( - OutputType::BinaryView, - )), - DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new( - OutputType::Binary, - )), - - // Use the generic accumulator based on `ScalarValue` for all other types - _ => Box::new(DistinctCountAccumulator { - values: HashSet::default(), - state_data_type: data_type.clone(), - }), + _ => get_count_accumulator(data_type), }) } @@ -758,7 +763,12 @@ impl Accumulator for DistinctCountAccumulator { #[cfg(test)] mod tests { use super::*; - use arrow::array::NullArray; + use arrow::array::{Int32Array, NullArray}; + use arrow::datatypes::{DataType, Field, Int32Type, Schema}; + use datafusion_expr::function::AccumulatorArgs; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr::LexOrdering; + use std::sync::Arc; #[test] fn count_accumulator_nulls() -> Result<()> { @@ -767,4 +777,49 @@ mod tests { assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0))); Ok(()) } + + #[test] + fn test_nested_dictionary() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "dict_col", + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + )), + ), + true, + )])); + + // Using Count UDAF's accumulator + let count = Count::new(); + let expr = Arc::new(Column::new("dict_col", 0)); + let args = AccumulatorArgs { + schema: &schema, + exprs: &[expr], + is_distinct: true, + name: "count", + ignore_nulls: false, + is_reversed: false, + return_field: Arc::new(Field::new_list_field(DataType::Int64, true)), + ordering_req: &LexOrdering::default(), + }; + + let inner_dict = arrow::array::DictionaryArray::<Int32Type>::from_iter([ + "a", "b", "c", "d", "a", "b", + ]); + + let keys = Int32Array::from(vec![0, 1, 2, 0, 3, 1]); + let dict_of_dict = arrow::array::DictionaryArray::<Int32Type>::try_new( + keys, + Arc::new(inner_dict), + )?; + + let mut acc = count.accumulator(args)?; + acc.update_batch(&[Arc::new(dict_of_dict)])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(4))); + + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 86e429f903..ed77435d6a 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5013,18 +5013,20 @@ set datafusion.sql_parser.dialect = 'Generic'; ## Multiple distinct aggregates and dictionaries statement ok -create table dict_test as values (1, arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('bar', 'Dictionary(Int32, Utf8)')); +create table dict_test as values (1, arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (1, arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('bar', 'Dictionary(Int32, Utf8)')), (1, arrow_cast('bar', 'Dictionary(Int32, Utf8)')); query IT -select * from dict_test; +select * from dict_test order by column1, column2; ---- +1 bar +1 foo 1 foo 2 bar query II -select count(distinct column1), count(distinct column2) from dict_test group by column1; +select count(distinct column1), count(distinct column2) from dict_test group by column1 order by column1; ---- -1 1 +1 2 1 1 statement ok --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org