This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new e24b058911 Fix DistinctCount for timestamps with time zone (#10043)
e24b058911 is described below

commit e24b0589112edd65f2652d2bba9766c3cc18bc97
Author: Georgi Krastev <[email protected]>
AuthorDate: Thu Apr 11 15:36:02 2024 +0300

    Fix DistinctCount for timestamps with time zone (#10043)
    
    * Fix DistinctCount for timestamps with time zone
    
    Preserve the original data type in the aggregation state
    
    * Add tests for decimal count distinct
---
 .../src/aggregate/count_distinct/mod.rs            | 42 ++++++++++++----------
 .../src/aggregate/count_distinct/native.rs         | 15 ++++++--
 datafusion/sqllogictest/test_files/aggregate.slt   | 37 ++++++++++++++++---
 datafusion/sqllogictest/test_files/decimal.slt     | 11 ++++++
 4 files changed, 79 insertions(+), 26 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs 
b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
index 9c5605f495..ee63945eb2 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
@@ -109,12 +109,14 @@ impl AggregateExpr for DistinctCount {
             UInt16 => 
Box::new(PrimitiveDistinctCountAccumulator::<UInt16Type>::new()),
             UInt32 => 
Box::new(PrimitiveDistinctCountAccumulator::<UInt32Type>::new()),
             UInt64 => 
Box::new(PrimitiveDistinctCountAccumulator::<UInt64Type>::new()),
-            Decimal128(_, _) => {
-                
Box::new(PrimitiveDistinctCountAccumulator::<Decimal128Type>::new())
-            }
-            Decimal256(_, _) => {
-                
Box::new(PrimitiveDistinctCountAccumulator::<Decimal256Type>::new())
-            }
+            dt @ Decimal128(_, _) => Box::new(
+                PrimitiveDistinctCountAccumulator::<Decimal128Type>::new()
+                    .with_data_type(dt.clone()),
+            ),
+            dt @ Decimal256(_, _) => Box::new(
+                PrimitiveDistinctCountAccumulator::<Decimal256Type>::new()
+                    .with_data_type(dt.clone()),
+            ),
 
             Date32 => 
Box::new(PrimitiveDistinctCountAccumulator::<Date32Type>::new()),
             Date64 => 
Box::new(PrimitiveDistinctCountAccumulator::<Date64Type>::new()),
@@ -130,18 +132,22 @@ impl AggregateExpr for DistinctCount {
             Time64(Nanosecond) => {
                 
Box::new(PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new())
             }
-            Timestamp(Microsecond, _) => 
Box::new(PrimitiveDistinctCountAccumulator::<
-                TimestampMicrosecondType,
-            >::new()),
-            Timestamp(Millisecond, _) => 
Box::new(PrimitiveDistinctCountAccumulator::<
-                TimestampMillisecondType,
-            >::new()),
-            Timestamp(Nanosecond, _) => 
Box::new(PrimitiveDistinctCountAccumulator::<
-                TimestampNanosecondType,
-            >::new()),
-            Timestamp(Second, _) => {
-                
Box::new(PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new())
-            }
+            dt @ Timestamp(Microsecond, _) => Box::new(
+                
PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new()
+                    .with_data_type(dt.clone()),
+            ),
+            dt @ Timestamp(Millisecond, _) => Box::new(
+                
PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new()
+                    .with_data_type(dt.clone()),
+            ),
+            dt @ Timestamp(Nanosecond, _) => Box::new(
+                
PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new()
+                    .with_data_type(dt.clone()),
+            ),
+            dt @ Timestamp(Second, _) => Box::new(
+                PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new()
+                    .with_data_type(dt.clone()),
+            ),
 
             Float16 => 
Box::new(FloatDistinctCountAccumulator::<Float16Type>::new()),
             Float32 => 
Box::new(FloatDistinctCountAccumulator::<Float32Type>::new()),
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs 
b/datafusion/physical-expr/src/aggregate/count_distinct/native.rs
index a44e8b772e..8f3ce8acfe 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct/native.rs
@@ -30,6 +30,7 @@ use ahash::RandomState;
 use arrow::array::ArrayRef;
 use arrow_array::types::ArrowPrimitiveType;
 use arrow_array::PrimitiveArray;
+use arrow_schema::DataType;
 
 use datafusion_common::cast::{as_list_array, as_primitive_array};
 use datafusion_common::utils::array_into_list_array;
@@ -45,6 +46,7 @@ where
     T::Native: Eq + Hash,
 {
     values: HashSet<T::Native, RandomState>,
+    data_type: DataType,
 }
 
 impl<T> PrimitiveDistinctCountAccumulator<T>
@@ -55,8 +57,14 @@ where
     pub(super) fn new() -> Self {
         Self {
             values: HashSet::default(),
+            data_type: T::DATA_TYPE,
         }
     }
+
+    pub(super) fn with_data_type(mut self, data_type: DataType) -> Self {
+        self.data_type = data_type;
+        self
+    }
 }
 
 impl<T> Accumulator for PrimitiveDistinctCountAccumulator<T>
@@ -65,9 +73,10 @@ where
     T::Native: Eq + Hash,
 {
     fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
-        let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
-            self.values.iter().cloned(),
-        )) as ArrayRef;
+        let arr = Arc::new(
+            PrimitiveArray::<T>::from_iter_values(self.values.iter().cloned())
+                .with_data_type(self.data_type.clone()),
+        );
         let list = Arc::new(array_into_list_array(arr));
         Ok(vec![ScalarValue::List(list)])
     }
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt 
b/datafusion/sqllogictest/test_files/aggregate.slt
index 4929ab485d..966236db27 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -1876,18 +1876,22 @@ select
   arrow_cast(column1, 'Timestamp(Microsecond, None)') as micros,
   arrow_cast(column1, 'Timestamp(Millisecond, None)') as millis,
   arrow_cast(column1, 'Timestamp(Second, None)') as secs,
+  arrow_cast(column1, 'Timestamp(Nanosecond, Some("UTC"))') as nanos_utc,
+  arrow_cast(column1, 'Timestamp(Microsecond, Some("UTC"))') as micros_utc,
+  arrow_cast(column1, 'Timestamp(Millisecond, Some("UTC"))') as millis_utc,
+  arrow_cast(column1, 'Timestamp(Second, Some("UTC"))') as secs_utc,
   column2 as names,
   column3 as tag
 from t_source;
 
 # Demonstrate the contents
-query PPPPTT
+query PPPPPPPPTT
 select * from t;
 ----
-2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 
2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 X
-2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 
2011-12-13T11:13:10 Row 1 X
-NULL NULL NULL NULL Row 2 Y
-2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 
2021-01-01T05:11:10 Row 3 Y
+2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 
2018-11-13T17:11:10.011 2018-11-13T17:11:10 2018-11-13T17:11:10.011375885Z 
2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 
X
+2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 
2011-12-13T11:13:10 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 
2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 X
+NULL NULL NULL NULL NULL NULL NULL NULL Row 2 Y
+2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 
2021-01-01T05:11:10 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 
2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 Y
 
 
 # aggregate_timestamps_sum
@@ -1933,6 +1937,17 @@ SELECT tag, max(nanos), max(micros), max(millis), 
max(secs) FROM t GROUP BY tag
 X 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 
2018-11-13T17:11:10.011 2018-11-13T17:11:10
 Y 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 
2021-01-01T05:11:10
 
+# aggregate_timestamps_count_distinct_with_tz
+query IIII
+SELECT count(DISTINCT nanos_utc), count(DISTINCT micros_utc), count(DISTINCT 
millis_utc), count(DISTINCT secs_utc) FROM t;
+----
+3 3 3 3
+
+query TIIII
+SELECT tag, count(DISTINCT nanos_utc), count(DISTINCT micros_utc), 
count(DISTINCT millis_utc), count(DISTINCT secs_utc) FROM t GROUP BY tag ORDER 
BY tag;
+----
+X 2 2 2 2
+Y 1 1 1 1
 
 # aggregate_timestamps_avg
 statement error DataFusion error: Error during planning: No function matches 
the given name and argument types 'AVG\(Timestamp\(Nanosecond, None\)\)'\. You 
might need to add explicit type casts\.
@@ -2285,6 +2300,18 @@ select c2, avg(c1), arrow_typeof(avg(c1)) from d_table 
GROUP BY c2 ORDER BY c2
 A 110.0045 Decimal128(14, 7)
 B -100.0045 Decimal128(14, 7)
 
+# aggregate_decimal_count_distinct
+query I
+select count(DISTINCT cast(c1 AS DECIMAL(10, 2))) from d_table
+----
+4
+
+query TI
+select c2, count(DISTINCT cast(c1 AS DECIMAL(10, 2))) from d_table GROUP BY c2 
ORDER BY c2
+----
+A 2
+B 2
+
 # Use PostgresSQL dialect
 statement ok
 set datafusion.sql_parser.dialect = 'Postgres';
diff --git a/datafusion/sqllogictest/test_files/decimal.slt 
b/datafusion/sqllogictest/test_files/decimal.slt
index c220a5fc9a..3f75e42d93 100644
--- a/datafusion/sqllogictest/test_files/decimal.slt
+++ b/datafusion/sqllogictest/test_files/decimal.slt
@@ -720,5 +720,16 @@ select count(*),c1 from decimal256_simple group by c1 
order by c1;
 4 0.00004
 5 0.00005
 
+query I
+select count(DISTINCT cast(c1 AS DECIMAL(42, 4))) from decimal256_simple;
+----
+2
+
+query BI
+select c4, count(DISTINCT cast(c1 AS DECIMAL(42, 4))) from decimal256_simple 
GROUP BY c4 ORDER BY c4;
+----
+false 2
+true 2
+
 statement ok
 drop table decimal256_simple;

Reply via email to