This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new e24b058911 Fix DistinctCount for timestamps with time zone (#10043)
e24b058911 is described below
commit e24b0589112edd65f2652d2bba9766c3cc18bc97
Author: Georgi Krastev <[email protected]>
AuthorDate: Thu Apr 11 15:36:02 2024 +0300
Fix DistinctCount for timestamps with time zone (#10043)
* Fix DistinctCount for timestamps with time zone
Preserve the original data type in the aggregation state
* Add tests for decimal count distinct
---
.../src/aggregate/count_distinct/mod.rs | 42 ++++++++++++----------
.../src/aggregate/count_distinct/native.rs | 15 ++++++--
datafusion/sqllogictest/test_files/aggregate.slt | 37 ++++++++++++++++---
datafusion/sqllogictest/test_files/decimal.slt | 11 ++++++
4 files changed, 79 insertions(+), 26 deletions(-)
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
index 9c5605f495..ee63945eb2 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
@@ -109,12 +109,14 @@ impl AggregateExpr for DistinctCount {
UInt16 =>
Box::new(PrimitiveDistinctCountAccumulator::<UInt16Type>::new()),
UInt32 =>
Box::new(PrimitiveDistinctCountAccumulator::<UInt32Type>::new()),
UInt64 =>
Box::new(PrimitiveDistinctCountAccumulator::<UInt64Type>::new()),
- Decimal128(_, _) => {
-
Box::new(PrimitiveDistinctCountAccumulator::<Decimal128Type>::new())
- }
- Decimal256(_, _) => {
-
Box::new(PrimitiveDistinctCountAccumulator::<Decimal256Type>::new())
- }
+ dt @ Decimal128(_, _) => Box::new(
+ PrimitiveDistinctCountAccumulator::<Decimal128Type>::new()
+ .with_data_type(dt.clone()),
+ ),
+ dt @ Decimal256(_, _) => Box::new(
+ PrimitiveDistinctCountAccumulator::<Decimal256Type>::new()
+ .with_data_type(dt.clone()),
+ ),
Date32 =>
Box::new(PrimitiveDistinctCountAccumulator::<Date32Type>::new()),
Date64 =>
Box::new(PrimitiveDistinctCountAccumulator::<Date64Type>::new()),
@@ -130,18 +132,22 @@ impl AggregateExpr for DistinctCount {
Time64(Nanosecond) => {
Box::new(PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new())
}
- Timestamp(Microsecond, _) =>
Box::new(PrimitiveDistinctCountAccumulator::<
- TimestampMicrosecondType,
- >::new()),
- Timestamp(Millisecond, _) =>
Box::new(PrimitiveDistinctCountAccumulator::<
- TimestampMillisecondType,
- >::new()),
- Timestamp(Nanosecond, _) =>
Box::new(PrimitiveDistinctCountAccumulator::<
- TimestampNanosecondType,
- >::new()),
- Timestamp(Second, _) => {
-
Box::new(PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new())
- }
+ dt @ Timestamp(Microsecond, _) => Box::new(
+
PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new()
+ .with_data_type(dt.clone()),
+ ),
+ dt @ Timestamp(Millisecond, _) => Box::new(
+
PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new()
+ .with_data_type(dt.clone()),
+ ),
+ dt @ Timestamp(Nanosecond, _) => Box::new(
+
PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new()
+ .with_data_type(dt.clone()),
+ ),
+ dt @ Timestamp(Second, _) => Box::new(
+ PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new()
+ .with_data_type(dt.clone()),
+ ),
Float16 =>
Box::new(FloatDistinctCountAccumulator::<Float16Type>::new()),
Float32 =>
Box::new(FloatDistinctCountAccumulator::<Float32Type>::new()),
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs
b/datafusion/physical-expr/src/aggregate/count_distinct/native.rs
index a44e8b772e..8f3ce8acfe 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct/native.rs
@@ -30,6 +30,7 @@ use ahash::RandomState;
use arrow::array::ArrayRef;
use arrow_array::types::ArrowPrimitiveType;
use arrow_array::PrimitiveArray;
+use arrow_schema::DataType;
use datafusion_common::cast::{as_list_array, as_primitive_array};
use datafusion_common::utils::array_into_list_array;
@@ -45,6 +46,7 @@ where
T::Native: Eq + Hash,
{
values: HashSet<T::Native, RandomState>,
+ data_type: DataType,
}
impl<T> PrimitiveDistinctCountAccumulator<T>
@@ -55,8 +57,14 @@ where
pub(super) fn new() -> Self {
Self {
values: HashSet::default(),
+ data_type: T::DATA_TYPE,
}
}
+
+ pub(super) fn with_data_type(mut self, data_type: DataType) -> Self {
+ self.data_type = data_type;
+ self
+ }
}
impl<T> Accumulator for PrimitiveDistinctCountAccumulator<T>
@@ -65,9 +73,10 @@ where
T::Native: Eq + Hash,
{
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
- let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
- self.values.iter().cloned(),
- )) as ArrayRef;
+ let arr = Arc::new(
+ PrimitiveArray::<T>::from_iter_values(self.values.iter().cloned())
+ .with_data_type(self.data_type.clone()),
+ );
let list = Arc::new(array_into_list_array(arr));
Ok(vec![ScalarValue::List(list)])
}
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt
b/datafusion/sqllogictest/test_files/aggregate.slt
index 4929ab485d..966236db27 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -1876,18 +1876,22 @@ select
arrow_cast(column1, 'Timestamp(Microsecond, None)') as micros,
arrow_cast(column1, 'Timestamp(Millisecond, None)') as millis,
arrow_cast(column1, 'Timestamp(Second, None)') as secs,
+ arrow_cast(column1, 'Timestamp(Nanosecond, Some("UTC"))') as nanos_utc,
+ arrow_cast(column1, 'Timestamp(Microsecond, Some("UTC"))') as micros_utc,
+ arrow_cast(column1, 'Timestamp(Millisecond, Some("UTC"))') as millis_utc,
+ arrow_cast(column1, 'Timestamp(Second, Some("UTC"))') as secs_utc,
column2 as names,
column3 as tag
from t_source;
# Demonstrate the contents
-query PPPPTT
+query PPPPPPPPTT
select * from t;
----
-2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375
2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 X
-2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123
2011-12-13T11:13:10 Row 1 X
-NULL NULL NULL NULL Row 2 Y
-2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432
2021-01-01T05:11:10 Row 3 Y
+2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375
2018-11-13T17:11:10.011 2018-11-13T17:11:10 2018-11-13T17:11:10.011375885Z
2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0
X
+2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123
2011-12-13T11:13:10 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z
2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 X
+NULL NULL NULL NULL NULL NULL NULL NULL Row 2 Y
+2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432
2021-01-01T05:11:10 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z
2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 Y
# aggregate_timestamps_sum
@@ -1933,6 +1937,17 @@ SELECT tag, max(nanos), max(micros), max(millis),
max(secs) FROM t GROUP BY tag
X 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375
2018-11-13T17:11:10.011 2018-11-13T17:11:10
Y 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432
2021-01-01T05:11:10
+# aggregate_timestamps_count_distinct_with_tz
+query IIII
+SELECT count(DISTINCT nanos_utc), count(DISTINCT micros_utc), count(DISTINCT
millis_utc), count(DISTINCT secs_utc) FROM t;
+----
+3 3 3 3
+
+query TIIII
+SELECT tag, count(DISTINCT nanos_utc), count(DISTINCT micros_utc),
count(DISTINCT millis_utc), count(DISTINCT secs_utc) FROM t GROUP BY tag ORDER
BY tag;
+----
+X 2 2 2 2
+Y 1 1 1 1
# aggregate_timestamps_avg
statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'AVG\(Timestamp\(Nanosecond, None\)\)'\. You
might need to add explicit type casts\.
@@ -2285,6 +2300,18 @@ select c2, avg(c1), arrow_typeof(avg(c1)) from d_table
GROUP BY c2 ORDER BY c2
A 110.0045 Decimal128(14, 7)
B -100.0045 Decimal128(14, 7)
+# aggregate_decimal_count_distinct
+query I
+select count(DISTINCT cast(c1 AS DECIMAL(10, 2))) from d_table
+----
+4
+
+query TI
+select c2, count(DISTINCT cast(c1 AS DECIMAL(10, 2))) from d_table GROUP BY c2
ORDER BY c2
+----
+A 2
+B 2
+
# Use PostgreSQL dialect
statement ok
set datafusion.sql_parser.dialect = 'Postgres';
diff --git a/datafusion/sqllogictest/test_files/decimal.slt
b/datafusion/sqllogictest/test_files/decimal.slt
index c220a5fc9a..3f75e42d93 100644
--- a/datafusion/sqllogictest/test_files/decimal.slt
+++ b/datafusion/sqllogictest/test_files/decimal.slt
@@ -720,5 +720,16 @@ select count(*),c1 from decimal256_simple group by c1
order by c1;
4 0.00004
5 0.00005
+query I
+select count(DISTINCT cast(c1 AS DECIMAL(42, 4))) from decimal256_simple;
+----
+2
+
+query BI
+select c4, count(DISTINCT cast(c1 AS DECIMAL(42, 4))) from decimal256_simple
GROUP BY c4 ORDER BY c4;
+----
+false 2
+true 2
+
statement ok
drop table decimal256_simple;