This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 3c3b22866a Fix incorrect results with multiple `COUNT(DISTINCT..)`
aggregates on dictionaries (#9679)
3c3b22866a is described below
commit 3c3b22866a7ece784208e9d499119b2e13399762
Author: Andrew Lamb <[email protected]>
AuthorDate: Tue Mar 19 11:38:25 2024 -0400
Fix incorrect results with multiple `COUNT(DISTINCT..)` aggregates on
dictionaries (#9679)
* Add test for multiple count distincts on a dictionary
* Fix accumulator merge bug
* Fix cleanup code
---
datafusion/common/src/scalar/mod.rs | 2 +-
.../src/aggregate/count_distinct/mod.rs | 32 ++++++++---
datafusion/sqllogictest/test_files/dictionary.slt | 67 ++++++++++++++++++++++
3 files changed, 93 insertions(+), 8 deletions(-)
diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs
index 5ace44f24b..316624175e 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -1746,7 +1746,7 @@ impl ScalarValue {
}
/// Converts `Vec<ScalarValue>` where each element has type corresponding to
- /// `data_type`, to a [`ListArray`].
+ /// `data_type`, to a single element [`ListArray`].
///
/// Example
/// ```
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
index 71782fcc5f..fb5e771049 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
@@ -47,7 +47,7 @@ use crate::binary_map::OutputType;
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
-/// Expression for a COUNT(DISTINCT) aggregation.
+/// Expression for a `COUNT(DISTINCT)` aggregation.
#[derive(Debug)]
pub struct DistinctCount {
/// Column name
@@ -100,6 +100,7 @@ impl AggregateExpr for DistinctCount {
use TimeUnit::*;
Ok(match &self.state_data_type {
+ // try to use a specialized accumulator if possible, otherwise fall back to the generic accumulator
Int8 =>
Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new()),
Int16 =>
Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new()),
Int32 =>
Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new()),
@@ -157,6 +158,7 @@ impl AggregateExpr for DistinctCount {
OutputType::Binary,
)),
+ // Use the generic accumulator based on `ScalarValue` for all other types
_ => Box::new(DistinctCountAccumulator {
values: HashSet::default(),
state_data_type: self.state_data_type.clone(),
@@ -183,7 +185,11 @@ impl PartialEq<dyn Any> for DistinctCount {
}
/// General purpose distinct accumulator that works for any DataType by using
-/// [`ScalarValue`]. Some types have specialized accumulators that are (much)
+/// [`ScalarValue`].
+///
+/// It stores intermediate results as a `ListArray`
+///
+/// Note that many types have specialized accumulators that are (much)
/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
/// [`BytesDistinctCountAccumulator`]
#[derive(Debug)]
@@ -193,8 +199,9 @@ struct DistinctCountAccumulator {
}
impl DistinctCountAccumulator {
- // calculating the size for fixed length values, taking first batch size * number of batches
- // This method is faster than .full_size(), however it is not suitable for variable length values like strings or complex types
+ // calculating the size for fixed length values, taking first batch size *
+ // number of batches. This method is faster than .full_size(), however it is
+ // not suitable for variable length values like strings or complex types
fn fixed_size(&self) -> usize {
std::mem::size_of_val(self)
+ (std::mem::size_of::<ScalarValue>() * self.values.capacity())
@@ -207,7 +214,8 @@ impl DistinctCountAccumulator {
+ std::mem::size_of::<DataType>()
}
- // calculates the size as accurate as possible, call to this method is expensive
+ // calculates the size as accurately as possible. Note that calling this
+ // method is expensive
fn full_size(&self) -> usize {
std::mem::size_of_val(self)
+ (std::mem::size_of::<ScalarValue>() * self.values.capacity())
@@ -221,6 +229,7 @@ impl DistinctCountAccumulator {
}
impl Accumulator for DistinctCountAccumulator {
+ /// Returns the distinct values seen so far as a (one element) `ListArray`.
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let scalars = self.values.iter().cloned().collect::<Vec<_>>();
let arr = ScalarValue::new_list(scalars.as_slice(),
&self.state_data_type);
@@ -246,6 +255,11 @@ impl Accumulator for DistinctCountAccumulator {
})
}
+ /// Merges multiple sets of distinct values into the current set.
+ ///
+ /// The input to this function is a `ListArray` with **multiple** rows,
+ /// where each row contains the values from a partial aggregate's phase (e.g.
+ /// the result of calling `Self::state` on multiple accumulators).
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
@@ -253,8 +267,12 @@ impl Accumulator for DistinctCountAccumulator {
assert_eq!(states.len(), 1, "array_agg states must be singleton!");
let array = &states[0];
let list_array = array.as_list::<i32>();
- let inner_array = list_array.value(0);
- self.update_batch(&[inner_array])
+ for inner_array in list_array.iter() {
+ let inner_array = inner_array
+ .expect("counts are always non null, so are intermediate
results");
+ self.update_batch(&[inner_array])?;
+ }
+ Ok(())
}
fn evaluate(&mut self) -> Result<ScalarValue> {
diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt
index 002aade252..af7bf5cb16 100644
--- a/datafusion/sqllogictest/test_files/dictionary.slt
+++ b/datafusion/sqllogictest/test_files/dictionary.slt
@@ -280,3 +280,70 @@ ORDER BY
2023-12-20T01:20:00 1000 f2 foo
2023-12-20T01:30:00 1000 f1 32.0
2023-12-20T01:30:00 1000 f2 foo
+
+# Cleanup
+statement ok
+drop view m1;
+
+statement ok
+drop view m2;
+
+######
+# Create a table using UNION ALL to get 2 partitions (very important: the bug only shows up when partial results from several partitions are merged)
+######
+statement ok
+create table m3_source as
+ select * from (values('foo', 'bar', 1))
+ UNION ALL
+ select * from (values('foo', 'baz', 1));
+
+######
+# Now, create a table with the same data, but column2 has type `Dictionary(Int32, Utf8)` to trigger the fallback code
+######
+statement ok
+create table m3 as
+ select
+ column1,
+ arrow_cast(column2, 'Dictionary(Int32, Utf8)') as "column2",
+ column3
+from m3_source;
+
+# there are two values in column2
+query T?I rowsort
+SELECT *
+FROM m3;
+----
+foo bar 1
+foo baz 1
+
+# There is 1 distinct value in column1
+query I
+SELECT count(distinct column1)
+FROM m3
+GROUP BY column3;
+----
+1
+
+# There are 2 distinct values in column2
+query I
+SELECT count(distinct column2)
+FROM m3
+GROUP BY column3;
+----
+2
+
+# Should still get the same results when both aggregates are computed in the same query
+query II
+SELECT count(distinct column1), count(distinct column2)
+FROM m3
+GROUP BY column3;
+----
+1 2
+
+
+# Cleanup
+statement ok
+drop table m3;
+
+statement ok
+drop table m3_source;