This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new d3f63728d2 Change `array_agg` to return `null` on no input rather than empty list (#11299)
d3f63728d2 is described below
commit d3f63728d222cc5cf30cf03a12ec9a0b41399b18
Author: Jay Zhan <[email protected]>
AuthorDate: Thu Jul 11 07:32:03 2024 +0800
Change `array_agg` to return `null` on no input rather than empty list (#11299)
* change array agg semantic for empty result
Signed-off-by: jayzhan211 <[email protected]>
* return null
Signed-off-by: jayzhan211 <[email protected]>
* fix test
Signed-off-by: jayzhan211 <[email protected]>
* fix order sensitive
Signed-off-by: jayzhan211 <[email protected]>
* fix test
Signed-off-by: jayzhan211 <[email protected]>
* add more test
Signed-off-by: jayzhan211 <[email protected]>
* fix null
Signed-off-by: jayzhan211 <[email protected]>
* fix multi-phase case
Signed-off-by: jayzhan211 <[email protected]>
* add comment
Signed-off-by: jayzhan211 <[email protected]>
* cleanup
Signed-off-by: jayzhan211 <[email protected]>
* fix clone
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
datafusion/common/src/scalar/mod.rs | 10 ++
datafusion/core/tests/dataframe/mod.rs | 2 +-
datafusion/core/tests/sql/aggregates.rs | 2 +-
datafusion/expr/src/aggregate_function.rs | 2 +-
.../physical-expr/src/aggregate/array_agg.rs | 17 ++-
.../src/aggregate/array_agg_distinct.rs | 11 +-
.../src/aggregate/array_agg_ordered.rs | 12 +-
datafusion/physical-expr/src/aggregate/build_in.rs | 4 +-
datafusion/sqllogictest/test_files/aggregate.slt | 155 +++++++++++++++------
9 files changed, 161 insertions(+), 54 deletions(-)
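
For context, here is a minimal sketch of the user-visible change (assuming a DataFusion build that includes this commit and `tokio` as the async runtime; the table and query mirror the new sqllogictest cases below): an `array_agg` over zero qualifying rows now evaluates to `NULL` instead of an empty list, following DuckDB's behaviour.

    use datafusion::error::Result;
    use datafusion::prelude::SessionContext;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();
        // DDL statements are executed eagerly by `SessionContext::sql`.
        ctx.sql("CREATE TABLE t(a INT) AS VALUES (1), (2)").await?;

        // No rows satisfy the predicate, so the aggregate sees no input.
        // Before this commit the result was `[]`; it is now `NULL`.
        let batches = ctx
            .sql("SELECT array_agg(a) FROM t WHERE a > 10")
            .await?
            .collect()
            .await?;
        datafusion::arrow::util::pretty::print_batches(&batches)?;
        Ok(())
    }
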
diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs
index c8f21788cb..6c03e8698e 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -1984,6 +1984,16 @@ impl ScalarValue {
Self::new_list(values, data_type, true)
}
+ /// Create a null `ListArray` with the specified element data type
+ ///
+ /// - new_null_list(i32, nullable, 1): `ListArray[NULL]`
+ pub fn new_null_list(data_type: DataType, nullable: bool, null_len: usize) -> Self {
+ let data_type = DataType::List(Field::new_list_field(data_type, nullable).into());
+ Self::List(Arc::new(ListArray::from(ArrayData::new_null(
+ &data_type, null_len,
+ ))))
+ }
+
/// Converts `IntoIterator<Item = ScalarValue>` where each element has type corresponding to
/// `data_type`, to a [`ListArray`].
///
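
A quick illustration of the new helper above (a sketch; the crate path `datafusion_common` is an assumption and may instead be `datafusion::common` depending on how you depend on the workspace):

    use arrow::array::Array;
    use arrow::datatypes::DataType;
    use datafusion_common::ScalarValue;

    fn main() {
        // A one-element list array whose single entry is NULL; this is what
        // the array_agg accumulators now return when they saw no input rows.
        let v = ScalarValue::new_null_list(DataType::Int32, true, 1);
        if let ScalarValue::List(arr) = &v {
            assert_eq!(arr.len(), 1);
            assert!(arr.is_null(0));
        }
        println!("{v:?}");
    }
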
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index 2d1904d9e1..f1d57c4429 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -1388,7 +1388,7 @@ async fn unnest_with_redundant_columns() -> Result<()> {
let expected = vec![
"Projection: shapes.shape_id [shape_id:UInt32]",
" Unnest: lists[shape_id2] structs[] [shape_id:UInt32, shape_id2:UInt32;N]",
- " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} })]",
+ " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} });N]",
" TableScan: shapes projection=[shape_id] [shape_id:UInt32]",
];
diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs
index e503b74992..86032dc9bc 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -37,7 +37,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
Schema::new(vec![Field::new_list(
"ARRAY_AGG(DISTINCT aggregate_test_100.c2)",
Field::new("item", DataType::UInt32, false),
- false
+ true
),])
);
diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs
index 23e98714df..3cae78eaed 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -118,7 +118,7 @@ impl AggregateFunction {
pub fn nullable(&self) -> Result<bool> {
match self {
AggregateFunction::Max | AggregateFunction::Min => Ok(true),
- AggregateFunction::ArrayAgg => Ok(false),
+ AggregateFunction::ArrayAgg => Ok(true),
}
}
}
diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs
index 634a0a0179..38a9738029 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg.rs
@@ -71,7 +71,7 @@ impl AggregateExpr for ArrayAgg {
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), self.nullable),
- false,
+ true,
))
}
@@ -86,7 +86,7 @@ impl AggregateExpr for ArrayAgg {
Ok(vec![Field::new_list(
format_state_name(&self.name, "array_agg"),
Field::new("item", self.input_data_type.clone(), self.nullable),
- false,
+ true,
)])
}
@@ -137,8 +137,11 @@ impl Accumulator for ArrayAggAccumulator {
return Ok(());
}
assert!(values.len() == 1, "array_agg can only take 1 param!");
+
let val = Arc::clone(&values[0]);
- self.values.push(val);
+ if val.len() > 0 {
+ self.values.push(val);
+ }
Ok(())
}
@@ -162,13 +165,15 @@ impl Accumulator for ArrayAggAccumulator {
fn evaluate(&mut self) -> Result<ScalarValue> {
// Transform Vec<ListArr> to ListArr
-
let element_arrays: Vec<&dyn Array> =
self.values.iter().map(|a| a.as_ref()).collect();
if element_arrays.is_empty() {
- let arr = ScalarValue::new_list(&[], &self.datatype, self.nullable);
- return Ok(ScalarValue::List(arr));
+ return Ok(ScalarValue::new_null_list(
+ self.datatype.clone(),
+ self.nullable,
+ 1,
+ ));
}
let concated_array = arrow::compute::concat(&element_arrays)?;
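
The same empty-input check recurs in the distinct and order-sensitive accumulators below. As a standalone sketch of the pattern (the helper name `evaluate_buffer` is hypothetical and not part of the commit):

    use arrow::datatypes::DataType;
    use datafusion_common::{Result, ScalarValue};

    // Hypothetical helper mirroring what the accumulators now do on evaluate:
    // an empty buffer yields a single NULL list entry rather than `[]`.
    fn evaluate_buffer(
        values: &[ScalarValue],
        data_type: &DataType,
        nullable: bool,
    ) -> Result<ScalarValue> {
        if values.is_empty() {
            return Ok(ScalarValue::new_null_list(data_type.clone(), nullable, 1));
        }
        Ok(ScalarValue::List(ScalarValue::new_list(
            values, data_type, nullable,
        )))
    }

    fn main() -> Result<()> {
        // No buffered values: the array_agg-style evaluation yields a NULL list.
        let empty = evaluate_buffer(&[], &DataType::Int64, true)?;
        println!("{empty:?}");
        Ok(())
    }
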
diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
index a59d85e84a..368d11d742 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
@@ -75,7 +75,7 @@ impl AggregateExpr for DistinctArrayAgg {
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), self.nullable),
- false,
+ true,
))
}
@@ -90,7 +90,7 @@ impl AggregateExpr for DistinctArrayAgg {
Ok(vec![Field::new_list(
format_state_name(&self.name, "distinct_array_agg"),
Field::new("item", self.input_data_type.clone(), self.nullable),
- false,
+ true,
)])
}
@@ -165,6 +165,13 @@ impl Accumulator for DistinctArrayAggAccumulator {
fn evaluate(&mut self) -> Result<ScalarValue> {
let values: Vec<ScalarValue> = self.values.iter().cloned().collect();
+ if values.is_empty() {
+ return Ok(ScalarValue::new_null_list(
+ self.datatype.clone(),
+ self.nullable,
+ 1,
+ ));
+ }
let arr = ScalarValue::new_list(&values, &self.datatype, self.nullable);
Ok(ScalarValue::List(arr))
}
diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
index a64d97637c..d44811192f 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
@@ -92,7 +92,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), self.nullable),
- false,
+ true,
))
}
@@ -111,7 +111,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
let mut fields = vec![Field::new_list(
format_state_name(&self.name, "array_agg"),
Field::new("item", self.input_data_type.clone(), self.nullable),
- false, // This should be the same as field()
+ true, // This should be the same as field()
)];
let orderings = ordering_fields(&self.ordering_req,
&self.order_by_data_types);
fields.push(Field::new_list(
@@ -309,6 +309,14 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
}
fn evaluate(&mut self) -> Result<ScalarValue> {
+ if self.values.is_empty() {
+ return Ok(ScalarValue::new_null_list(
+ self.datatypes[0].clone(),
+ self.nullable,
+ 1,
+ ));
+ }
+
let values = self.values.clone();
let array = if self.reverse {
ScalarValue::new_list_from_iter(
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index d4cd3d51d1..68c9b4859f 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -147,7 +147,7 @@ mod tests {
Field::new_list(
"c1",
Field::new("item", data_type.clone(), true),
- false,
+ true,
),
result_agg_phy_exprs.field().unwrap()
);
@@ -167,7 +167,7 @@ mod tests {
Field::new_list(
"c1",
Field::new("item", data_type.clone(), true),
- false,
+ true,
),
result_agg_phy_exprs.field().unwrap()
);
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index e891093c81..7dd1ea82b3 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -1694,7 +1694,7 @@ SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT
query ?
SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 LIMIT 0) test
----
-[]
+NULL
# csv_query_array_agg_one
query ?
@@ -1753,31 +1753,12 @@ NULL 4 29 1.260869565217 123 -117 23
NULL 5 -194 -13.857142857143 118 -101 14
NULL NULL 781 7.81 125 -117 100
-# TODO: array_agg_distinct output is non-deterministic -- rewrite with array_sort(list_sort)
-# unnest is also not available, so manually unnesting via CROSS JOIN
-# additional count(1) forces array_agg_distinct instead of array_agg over aggregated by c2 data
-#
+# select with count to force the array_agg_distinct function, since the optimizer converts a single distinct expression into a group by
# csv_query_array_agg_distinct
-query III
-WITH indices AS (
- SELECT 1 AS idx UNION ALL
- SELECT 2 AS idx UNION ALL
- SELECT 3 AS idx UNION ALL
- SELECT 4 AS idx UNION ALL
- SELECT 5 AS idx
-)
-SELECT data.arr[indices.idx] as element, array_length(data.arr) as array_len, dummy
-FROM (
- SELECT array_agg(distinct c2) as arr, count(1) as dummy FROM aggregate_test_100
-) data
- CROSS JOIN indices
-ORDER BY 1
-----
-1 5 100
-2 5 100
-3 5 100
-4 5 100
-5 5 100
+query ?I
+SELECT array_sort(array_agg(distinct c2)), count(1) FROM aggregate_test_100
+----
+[1, 2, 3, 4, 5] 100
# aggregate_time_min_and_max
query TT
@@ -2732,6 +2713,16 @@ SELECT COUNT(DISTINCT c1) FROM test
# TODO: aggregate_with_alias
+# test_approx_percentile_cont_decimal_support
+query TI
+SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1
+----
+a 4
+b 5
+c 4
+d 4
+e 4
+
# array_agg_zero
query ?
SELECT ARRAY_AGG([])
@@ -2744,28 +2735,114 @@ SELECT ARRAY_AGG([1])
----
[[1]]
-# test_approx_percentile_cont_decimal_support
-query TI
-SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1
+# test array_agg when no rows qualify
+statement ok
+create table t(a int, b float, c bigint) as values (1, 1.2, 2);
+
+# returns NULL, follows DuckDB's behaviour
+query ?
+select array_agg(a) from t where a > 2;
----
-a 4
-b 5
-c 4
-d 4
-e 4
+NULL
+query ?
+select array_agg(b) from t where b > 3.1;
+----
+NULL
-# array_agg_zero
query ?
-SELECT ARRAY_AGG([]);
+select array_agg(c) from t where c > 3;
----
-[[]]
+NULL
-# array_agg_one
+query ?I
+select array_agg(c), count(1) from t where c > 3;
+----
+NULL 0
+
+# returns 0 rows if group by is applied, follows DuckDB's behaviour
query ?
-SELECT ARRAY_AGG([1]);
+select array_agg(a) from t where a > 3 group by a;
----
-[[1]]
+
+query ?I
+select array_agg(a), count(1) from t where a > 3 group by a;
+----
+
+# returns NULL, follows DuckDB's behaviour
+query ?
+select array_agg(distinct a) from t where a > 3;
+----
+NULL
+
+query ?I
+select array_agg(distinct a), count(1) from t where a > 3;
+----
+NULL 0
+
+# returns 0 rows if group by is applied, follows DuckDB's behaviour
+query ?
+select array_agg(distinct a) from t where a > 3 group by a;
+----
+
+query ?I
+select array_agg(distinct a), count(1) from t where a > 3 group by a;
+----
+
+# test order sensitive array agg
+query ?
+select array_agg(a order by a) from t where a > 3;
+----
+NULL
+
+query ?
+select array_agg(a order by a) from t where a > 3 group by a;
+----
+
+query ?I
+select array_agg(a order by a), count(1) from t where a > 3 group by a;
+----
+
+statement ok
+drop table t;
+
+# test with no values
+statement ok
+create table t(a int, b float, c bigint);
+
+query ?
+select array_agg(a) from t;
+----
+NULL
+
+query ?
+select array_agg(b) from t;
+----
+NULL
+
+query ?
+select array_agg(c) from t;
+----
+NULL
+
+query ?I
+select array_agg(distinct a), count(1) from t;
+----
+NULL 0
+
+query ?I
+select array_agg(distinct b), count(1) from t;
+----
+NULL 0
+
+query ?I
+select array_agg(distinct c), count(1) from t;
+----
+NULL 0
+
+statement ok
+drop table t;
+
# array_agg_i32
statement ok
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]