This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new d3f63728d2 Change `array_agg` to return `null` on no input rather than 
empty list (#11299)
d3f63728d2 is described below

commit d3f63728d222cc5cf30cf03a12ec9a0b41399b18
Author: Jay Zhan <[email protected]>
AuthorDate: Thu Jul 11 07:32:03 2024 +0800

    Change `array_agg` to return `null` on no input rather than empty list 
(#11299)
    
    * change array agg semantic for empty result
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * return null
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix order sensitive
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add more test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix null
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix multi-phase case
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add comment
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * cleanup
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix clone
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 datafusion/common/src/scalar/mod.rs                |  10 ++
 datafusion/core/tests/dataframe/mod.rs             |   2 +-
 datafusion/core/tests/sql/aggregates.rs            |   2 +-
 datafusion/expr/src/aggregate_function.rs          |   2 +-
 .../physical-expr/src/aggregate/array_agg.rs       |  17 ++-
 .../src/aggregate/array_agg_distinct.rs            |  11 +-
 .../src/aggregate/array_agg_ordered.rs             |  12 +-
 datafusion/physical-expr/src/aggregate/build_in.rs |   4 +-
 datafusion/sqllogictest/test_files/aggregate.slt   | 155 +++++++++++++++------
 9 files changed, 161 insertions(+), 54 deletions(-)

diff --git a/datafusion/common/src/scalar/mod.rs 
b/datafusion/common/src/scalar/mod.rs
index c8f21788cb..6c03e8698e 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -1984,6 +1984,16 @@ impl ScalarValue {
         Self::new_list(values, data_type, true)
     }
 
+    /// Create a [`ScalarValue::List`] with `null_len` null entries whose items are `data_type` (nullable iff `nullable`)
+    ///
+    /// - new_null_list(DataType::Int32, true, 1): `ListArray[NULL]`
+    pub fn new_null_list(data_type: DataType, nullable: bool, null_len: usize) 
-> Self {
+        let data_type = DataType::List(Field::new_list_field(data_type, 
nullable).into());
+        Self::List(Arc::new(ListArray::from(ArrayData::new_null(
+            &data_type, null_len,
+        ))))
+    }
+
+
     /// Converts `IntoIterator<Item = ScalarValue>` where each element has 
type corresponding to
     /// `data_type`, to a [`ListArray`].
     ///
diff --git a/datafusion/core/tests/dataframe/mod.rs 
b/datafusion/core/tests/dataframe/mod.rs
index 2d1904d9e1..f1d57c4429 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -1388,7 +1388,7 @@ async fn unnest_with_redundant_columns() -> Result<()> {
     let expected = vec![
         "Projection: shapes.shape_id [shape_id:UInt32]",
         "  Unnest: lists[shape_id2] structs[] [shape_id:UInt32, 
shape_id2:UInt32;N]",
-        "    Aggregate: groupBy=[[shapes.shape_id]], 
aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, 
shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, 
dict_id: 0, dict_is_ordered: false, metadata: {} })]",
+        "    Aggregate: groupBy=[[shapes.shape_id]], 
aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, 
shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, 
dict_id: 0, dict_is_ordered: false, metadata: {} });N]",
         "      TableScan: shapes projection=[shape_id] [shape_id:UInt32]",
     ];
 
diff --git a/datafusion/core/tests/sql/aggregates.rs 
b/datafusion/core/tests/sql/aggregates.rs
index e503b74992..86032dc9bc 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -37,7 +37,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
         Schema::new(vec![Field::new_list(
             "ARRAY_AGG(DISTINCT aggregate_test_100.c2)",
             Field::new("item", DataType::UInt32, false),
-            false
+            true
         ),])
     );
 
diff --git a/datafusion/expr/src/aggregate_function.rs 
b/datafusion/expr/src/aggregate_function.rs
index 23e98714df..3cae78eaed 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -118,7 +118,7 @@ impl AggregateFunction {
     pub fn nullable(&self) -> Result<bool> {
         match self {
             AggregateFunction::Max | AggregateFunction::Min => Ok(true),
-            AggregateFunction::ArrayAgg => Ok(false),
+            AggregateFunction::ArrayAgg => Ok(true),
         }
     }
 }
diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs 
b/datafusion/physical-expr/src/aggregate/array_agg.rs
index 634a0a0179..38a9738029 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg.rs
@@ -71,7 +71,7 @@ impl AggregateExpr for ArrayAgg {
             &self.name,
             // This should be the same as return type of 
AggregateFunction::ArrayAgg
             Field::new("item", self.input_data_type.clone(), self.nullable),
-            false,
+            true,
         ))
     }
 
@@ -86,7 +86,7 @@ impl AggregateExpr for ArrayAgg {
         Ok(vec![Field::new_list(
             format_state_name(&self.name, "array_agg"),
             Field::new("item", self.input_data_type.clone(), self.nullable),
-            false,
+            true,
         )])
     }
 
@@ -137,8 +137,11 @@ impl Accumulator for ArrayAggAccumulator {
             return Ok(());
         }
         assert!(values.len() == 1, "array_agg can only take 1 param!");
+
         let val = Arc::clone(&values[0]);
-        self.values.push(val);
+        if val.len() > 0 {
+            self.values.push(val);
+        }
         Ok(())
     }
 
@@ -162,13 +165,15 @@ impl Accumulator for ArrayAggAccumulator {
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
         // Transform Vec<ListArr> to ListArr
-
         let element_arrays: Vec<&dyn Array> =
             self.values.iter().map(|a| a.as_ref()).collect();
 
         if element_arrays.is_empty() {
-            let arr = ScalarValue::new_list(&[], &self.datatype, 
self.nullable);
-            return Ok(ScalarValue::List(arr));
+            return Ok(ScalarValue::new_null_list(
+                self.datatype.clone(),
+                self.nullable,
+                1,
+            ));
         }
 
         let concated_array = arrow::compute::concat(&element_arrays)?;
diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs 
b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
index a59d85e84a..368d11d742 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
@@ -75,7 +75,7 @@ impl AggregateExpr for DistinctArrayAgg {
             &self.name,
             // This should be the same as return type of 
AggregateFunction::ArrayAgg
             Field::new("item", self.input_data_type.clone(), self.nullable),
-            false,
+            true,
         ))
     }
 
@@ -90,7 +90,7 @@ impl AggregateExpr for DistinctArrayAgg {
         Ok(vec![Field::new_list(
             format_state_name(&self.name, "distinct_array_agg"),
             Field::new("item", self.input_data_type.clone(), self.nullable),
-            false,
+            true,
         )])
     }
 
@@ -165,6 +165,13 @@ impl Accumulator for DistinctArrayAggAccumulator {
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
         let values: Vec<ScalarValue> = self.values.iter().cloned().collect();
+        if values.is_empty() {
+            return Ok(ScalarValue::new_null_list(
+                self.datatype.clone(),
+                self.nullable,
+                1,
+            ));
+        }
         let arr = ScalarValue::new_list(&values, &self.datatype, 
self.nullable);
         Ok(ScalarValue::List(arr))
     }
diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs 
b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
index a64d97637c..d44811192f 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
@@ -92,7 +92,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
             &self.name,
             // This should be the same as return type of 
AggregateFunction::ArrayAgg
             Field::new("item", self.input_data_type.clone(), self.nullable),
-            false,
+            true,
         ))
     }
 
@@ -111,7 +111,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
         let mut fields = vec![Field::new_list(
             format_state_name(&self.name, "array_agg"),
             Field::new("item", self.input_data_type.clone(), self.nullable),
-            false, // This should be the same as field()
+            true, // This should be the same as field()
         )];
         let orderings = ordering_fields(&self.ordering_req, 
&self.order_by_data_types);
         fields.push(Field::new_list(
@@ -309,6 +309,14 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
     }
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
+        if self.values.is_empty() {
+            return Ok(ScalarValue::new_null_list(
+                self.datatypes[0].clone(),
+                self.nullable,
+                1,
+            ));
+        }
+
         let values = self.values.clone();
         let array = if self.reverse {
             ScalarValue::new_list_from_iter(
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs 
b/datafusion/physical-expr/src/aggregate/build_in.rs
index d4cd3d51d1..68c9b4859f 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -147,7 +147,7 @@ mod tests {
                         Field::new_list(
                             "c1",
                             Field::new("item", data_type.clone(), true),
-                            false,
+                            true,
                         ),
                         result_agg_phy_exprs.field().unwrap()
                     );
@@ -167,7 +167,7 @@ mod tests {
                         Field::new_list(
                             "c1",
                             Field::new("item", data_type.clone(), true),
-                            false,
+                            true,
                         ),
                         result_agg_phy_exprs.field().unwrap()
                     );
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt 
b/datafusion/sqllogictest/test_files/aggregate.slt
index e891093c81..7dd1ea82b3 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -1694,7 +1694,7 @@ SELECT array_agg(c13) FROM (SELECT * FROM 
aggregate_test_100 ORDER BY c13 LIMIT
 query ?
 SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 LIMIT 0) test
 ----
-[]
+NULL
 
 # csv_query_array_agg_one
 query ?
@@ -1753,31 +1753,12 @@ NULL 4 29 1.260869565217 123 -117 23
 NULL 5 -194 -13.857142857143 118 -101 14
 NULL NULL 781 7.81 125 -117 100
 
-# TODO: array_agg_distinct output is non-deterministic -- rewrite with 
array_sort(list_sort)
-#       unnest is also not available, so manually unnesting via CROSS JOIN
-# additional count(1) forces array_agg_distinct instead of array_agg over 
aggregated by c2 data
-#
+# select with count(1) to force the array_agg_distinct function, since the optimizer rewrites a single distinct aggregate into a GROUP BY
 # csv_query_array_agg_distinct
-query III
-WITH indices AS (
-  SELECT 1 AS idx UNION ALL
-  SELECT 2 AS idx UNION ALL
-  SELECT 3 AS idx UNION ALL
-  SELECT 4 AS idx UNION ALL
-  SELECT 5 AS idx
-)
-SELECT data.arr[indices.idx] as element, array_length(data.arr) as array_len, 
dummy
-FROM (
-  SELECT array_agg(distinct c2) as arr, count(1) as dummy FROM 
aggregate_test_100
-) data
-  CROSS JOIN indices
-ORDER BY 1
-----
-1 5 100
-2 5 100
-3 5 100
-4 5 100
-5 5 100
+query ?I
+SELECT array_sort(array_agg(distinct c2)), count(1) FROM aggregate_test_100
+----
+[1, 2, 3, 4, 5] 100
 
 # aggregate_time_min_and_max
 query TT
@@ -2732,6 +2713,16 @@ SELECT COUNT(DISTINCT c1) FROM test
 
 # TODO: aggregate_with_alias
 
+# test_approx_percentile_cont_decimal_support
+query TI
+SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM 
aggregate_test_100 GROUP BY 1 ORDER BY 1
+----
+a 4
+b 5
+c 4
+d 4
+e 4
+
 # array_agg_zero
 query ?
 SELECT ARRAY_AGG([])
@@ -2744,28 +2735,114 @@ SELECT ARRAY_AGG([1])
 ----
 [[1]]
 
-# test_approx_percentile_cont_decimal_support
-query TI
-SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM 
aggregate_test_100 GROUP BY 1 ORDER BY 1
+# test array_agg when no rows qualify
+statement ok
+create table t(a int, b float, c bigint) as values (1, 1.2, 2);
+
+# returns NULL, follows DuckDB's behaviour
+query ?
+select array_agg(a) from t where a > 2;
 ----
-a 4
-b 5
-c 4
-d 4
-e 4
+NULL
 
+query ?
+select array_agg(b) from t where b > 3.1;
+----
+NULL
 
-# array_agg_zero
 query ?
-SELECT ARRAY_AGG([]);
+select array_agg(c) from t where c > 3;
 ----
-[[]]
+NULL
 
-# array_agg_one
+query ?I
+select array_agg(c), count(1) from t where c > 3;
+----
+NULL 0
+
+# returns 0 rows if group by is applied, follows DuckDB's behaviour
 query ?
-SELECT ARRAY_AGG([1]);
+select array_agg(a) from t where a > 3 group by a;
 ----
-[[1]]
+
+query ?I
+select array_agg(a), count(1) from t where a > 3 group by a;
+----
+
+# returns NULL, follows DuckDB's behaviour
+query ?
+select array_agg(distinct a) from t where a > 3;
+----
+NULL
+
+query ?I
+select array_agg(distinct a), count(1) from t where a > 3;
+----
+NULL 0
+
+# returns 0 rows if group by is applied, follows DuckDB's behaviour
+query ?
+select array_agg(distinct a) from t where a > 3 group by a;
+----
+
+query ?I
+select array_agg(distinct a), count(1) from t where a > 3 group by a;
+----
+
+# test order sensitive array agg
+query ?
+select array_agg(a order by a) from t where a > 3;
+----
+NULL
+
+query ?
+select array_agg(a order by a) from t where a > 3 group by a;
+----
+
+query ?I
+select array_agg(a order by a), count(1) from t where a > 3 group by a;
+----
+
+statement ok
+drop table t;
+
+# test with no values
+statement ok
+create table t(a int, b float, c bigint);
+
+query ?
+select array_agg(a) from t;
+----
+NULL
+
+query ?
+select array_agg(b) from t;
+----
+NULL
+
+query ?
+select array_agg(c) from t;
+----
+NULL
+
+query ?I
+select array_agg(distinct a), count(1) from t;
+----
+NULL 0
+
+query ?I
+select array_agg(distinct b), count(1) from t;
+----
+NULL 0
+
+query ?I
+select array_agg(distinct c), count(1) from t;
+----
+NULL 0
+
+statement ok
+drop table t;
+
 
 # array_agg_i32
 statement ok


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to