This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 3c3b22866a Fix incorrect results with multiple `COUNT(DISTINCT..)` 
aggregates on dictionaries (#9679)
3c3b22866a is described below

commit 3c3b22866a7ece784208e9d499119b2e13399762
Author: Andrew Lamb <[email protected]>
AuthorDate: Tue Mar 19 11:38:25 2024 -0400

    Fix incorrect results with multiple `COUNT(DISTINCT..)` aggregates on 
dictionaries (#9679)
    
    * Add test for multiple count distincts on a dictionary
    
    * Fix accumulator merge bug
    
    * Fix cleanup code
---
 datafusion/common/src/scalar/mod.rs                |  2 +-
 .../src/aggregate/count_distinct/mod.rs            | 32 ++++++++---
 datafusion/sqllogictest/test_files/dictionary.slt  | 67 ++++++++++++++++++++++
 3 files changed, 93 insertions(+), 8 deletions(-)

diff --git a/datafusion/common/src/scalar/mod.rs 
b/datafusion/common/src/scalar/mod.rs
index 5ace44f24b..316624175e 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -1746,7 +1746,7 @@ impl ScalarValue {
     }
 
     /// Converts `Vec<ScalarValue>` where each element has type corresponding 
to
-    /// `data_type`, to a [`ListArray`].
+    /// `data_type`, to a single element [`ListArray`].
     ///
     /// Example
     /// ```
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs 
b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
index 71782fcc5f..fb5e771049 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
@@ -47,7 +47,7 @@ use crate::binary_map::OutputType;
 use crate::expressions::format_state_name;
 use crate::{AggregateExpr, PhysicalExpr};
 
-/// Expression for a COUNT(DISTINCT) aggregation.
+/// Expression for a `COUNT(DISTINCT)` aggregation.
 #[derive(Debug)]
 pub struct DistinctCount {
     /// Column name
@@ -100,6 +100,7 @@ impl AggregateExpr for DistinctCount {
         use TimeUnit::*;
 
         Ok(match &self.state_data_type {
+            // try and use a specialized accumulator if possible, otherwise 
fall back to generic accumulator
             Int8 => 
Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new()),
             Int16 => 
Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new()),
             Int32 => 
Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new()),
@@ -157,6 +158,7 @@ impl AggregateExpr for DistinctCount {
                 OutputType::Binary,
             )),
 
+            // Use the generic accumulator based on `ScalarValue` for all 
other types
             _ => Box::new(DistinctCountAccumulator {
                 values: HashSet::default(),
                 state_data_type: self.state_data_type.clone(),
@@ -183,7 +185,11 @@ impl PartialEq<dyn Any> for DistinctCount {
 }
 
 /// General purpose distinct accumulator that works for any DataType by using
-/// [`ScalarValue`]. Some types have specialized accumulators that are (much)
+/// [`ScalarValue`].
+///
+/// It stores intermediate results as a `ListArray`.
+///
+/// Note that many types have specialized accumulators that are (much)
 /// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
 /// [`BytesDistinctCountAccumulator`]
 #[derive(Debug)]
@@ -193,8 +199,9 @@ struct DistinctCountAccumulator {
 }
 
 impl DistinctCountAccumulator {
-    // calculating the size for fixed length values, taking first batch size * 
number of batches
-    // This method is faster than .full_size(), however it is not suitable for 
variable length values like strings or complex types
+    // calculating the size for fixed length values, taking first batch size *
+    // number of batches. This method is faster than .full_size(), however it is
+    // not suitable for variable length values like strings or complex types
     fn fixed_size(&self) -> usize {
         std::mem::size_of_val(self)
             + (std::mem::size_of::<ScalarValue>() * self.values.capacity())
@@ -207,7 +214,8 @@ impl DistinctCountAccumulator {
             + std::mem::size_of::<DataType>()
     }
 
-    // calculates the size as accurate as possible, call to this method is 
expensive
+    // calculates the size as accurately as possible. Note that calling this
+    // method is expensive
     fn full_size(&self) -> usize {
         std::mem::size_of_val(self)
             + (std::mem::size_of::<ScalarValue>() * self.values.capacity())
@@ -221,6 +229,7 @@ impl DistinctCountAccumulator {
 }
 
 impl Accumulator for DistinctCountAccumulator {
+    /// Returns the distinct values seen so far as a (one element) `ListArray`.
     fn state(&mut self) -> Result<Vec<ScalarValue>> {
         let scalars = self.values.iter().cloned().collect::<Vec<_>>();
         let arr = ScalarValue::new_list(scalars.as_slice(), 
&self.state_data_type);
@@ -246,6 +255,11 @@ impl Accumulator for DistinctCountAccumulator {
         })
     }
 
+    /// Merges multiple sets of distinct values into the current set.
+    ///
+    /// The input to this function is a `ListArray` with **multiple** rows,
+    /// where each row contains the values from a partial aggregate's phase 
(e.g.
+    /// the result of calling `Self::state` on multiple accumulators).
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
         if states.is_empty() {
             return Ok(());
@@ -253,8 +267,12 @@ impl Accumulator for DistinctCountAccumulator {
         assert_eq!(states.len(), 1, "array_agg states must be singleton!");
         let array = &states[0];
         let list_array = array.as_list::<i32>();
-        let inner_array = list_array.value(0);
-        self.update_batch(&[inner_array])
+        for inner_array in list_array.iter() {
+            let inner_array = inner_array
+                .expect("counts are always non null, so are intermediate 
results");
+            self.update_batch(&[inner_array])?;
+        }
+        Ok(())
     }
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
diff --git a/datafusion/sqllogictest/test_files/dictionary.slt 
b/datafusion/sqllogictest/test_files/dictionary.slt
index 002aade252..af7bf5cb16 100644
--- a/datafusion/sqllogictest/test_files/dictionary.slt
+++ b/datafusion/sqllogictest/test_files/dictionary.slt
@@ -280,3 +280,70 @@ ORDER BY
 2023-12-20T01:20:00 1000 f2 foo
 2023-12-20T01:30:00 1000 f1 32.0
 2023-12-20T01:30:00 1000 f2 foo
+
+# Cleanup
+statement ok
+drop view m1;
+
+statement ok
+drop view m2;
+
+######
+# Create a table using UNION ALL to get 2 partitions (very important)
+######
+statement ok
+create table m3_source as
+    select * from (values('foo', 'bar', 1))
+    UNION ALL
+    select * from (values('foo', 'baz', 1));
+
+######
+# Now, create a table with the same data, but column2 has type 
`Dictionary(Int32, Utf8)` to trigger the fallback code
+######
+statement ok
+create table m3 as
+  select
+    column1,
+    arrow_cast(column2, 'Dictionary(Int32, Utf8)') as "column2",
+    column3
+from m3_source;
+
+# there are two values in column2
+query T?I rowsort
+SELECT *
+FROM m3;
+----
+foo bar 1
+foo baz 1
+
+# There is 1 distinct value in column1
+query I
+SELECT count(distinct column1)
+FROM m3
+GROUP BY column3;
+----
+1
+
+# There are 2 distinct values in column2
+query I
+SELECT count(distinct column2)
+FROM m3
+GROUP BY column3;
+----
+2
+
+# Should still get the same results when querying in the same query
+query II
+SELECT count(distinct column1), count(distinct column2)
+FROM m3
+GROUP BY column3;
+----
+1 2
+
+
+# Cleanup
+statement ok
+drop table m3;
+
+statement ok
+drop table m3_source;

Reply via email to