This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new fc84a639fc Support List for Array aggregate order and distinct (#9234)
fc84a639fc is described below

commit fc84a639fca7716e529384c0e919fb90b75139da
Author: Jay Zhan <[email protected]>
AuthorDate: Tue Feb 20 15:41:42 2024 +0800

    Support List for Array aggregate order and distinct (#9234)
    
    * first draft
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix convert_first_level_array_to_scalar_vec
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add doc
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix nth
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * support distinct
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * cleanup
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * rm convert_first_level_array_to_scalar_vec
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add doc and assertion
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix doc
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix doc
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 datafusion/common/src/scalar/mod.rs                |  74 ++++++---
 datafusion/core/tests/sql/aggregates.rs            |   2 +-
 .../src/aggregate/array_agg_distinct.rs            | 171 ++++++++++-----------
 .../src/aggregate/count_distinct/mod.rs            |  10 +-
 datafusion/sqllogictest/test_files/aggregate.slt   |  57 ++++++-
 .../sqllogictest/test_files/aggregates_topk.slt    |   3 +
 6 files changed, 194 insertions(+), 123 deletions(-)

diff --git a/datafusion/common/src/scalar/mod.rs 
b/datafusion/common/src/scalar/mod.rs
index 7e53415090..80dddd8a83 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -52,7 +52,6 @@ use arrow::{
         UInt16Type, UInt32Type, UInt64Type, UInt8Type, 
DECIMAL128_MAX_PRECISION,
     },
 };
-use arrow_array::cast::as_list_array;
 use arrow_array::{ArrowNativeTypeOp, Scalar};
 
 pub use struct_builder::ScalarStructBuilder;
@@ -2138,28 +2137,67 @@ impl ScalarValue {
 
     /// Retrieve ScalarValue for each row in `array`
     ///
-    /// Example
+    /// Example 1: Array (ScalarValue::Int32)
     /// ```
     /// use datafusion_common::ScalarValue;
     /// use arrow::array::ListArray;
     /// use arrow::datatypes::{DataType, Int32Type};
     ///
+    /// // Equivalent to [[1,2,3], [4,5]]
     /// let list_arr = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
     ///    Some(vec![Some(1), Some(2), Some(3)]),
-    ///    None,
     ///    Some(vec![Some(4), Some(5)])
     /// ]);
     ///
+    /// // Convert the array into Scalar Values for each row
     /// let scalar_vec = 
ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap();
     ///
     /// let expected = vec![
-    ///   vec![
+    /// vec![
     ///     ScalarValue::Int32(Some(1)),
     ///     ScalarValue::Int32(Some(2)),
     ///     ScalarValue::Int32(Some(3)),
+    /// ],
+    /// vec![
+    ///    ScalarValue::Int32(Some(4)),
+    ///    ScalarValue::Int32(Some(5)),
+    /// ],
+    /// ];
+    ///
+    /// assert_eq!(scalar_vec, expected);
+    /// ```
+    ///
+    /// Example 2: Nested array (ScalarValue::List)
+    /// ```
+    /// use datafusion_common::ScalarValue;
+    /// use arrow::array::ListArray;
+    /// use arrow::datatypes::{DataType, Int32Type};
+    /// use datafusion_common::utils::array_into_list_array;
+    /// use std::sync::Arc;
+    ///
+    /// let list_arr = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+    ///    Some(vec![Some(1), Some(2), Some(3)]),
+    ///    Some(vec![Some(4), Some(5)])
+    /// ]);
+    ///
+    /// // Wrap in another list layer, producing the nested array [ [[1,2,3], 
[4,5]] ]
+    /// let list_arr = array_into_list_array(Arc::new(list_arr));
+    ///
+    /// // Convert the array into ScalarValues for each row; this yields 1D arrays 
in this example
+    /// let scalar_vec = 
ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap();
+    ///
+    /// let l1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+    ///     Some(vec![Some(1), Some(2), Some(3)]),
+    /// ]);
+    /// let l2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+    ///     Some(vec![Some(4), Some(5)]),
+    /// ]);
+    ///
+    /// let expected = vec![
+    ///   vec![
+    ///     ScalarValue::List(Arc::new(l1)),
+    ///     ScalarValue::List(Arc::new(l2)),
     ///   ],
-    ///   vec![],
-    ///   vec![ScalarValue::Int32(Some(4)), ScalarValue::Int32(Some(5))]
     /// ];
     ///
     /// assert_eq!(scalar_vec, expected);
@@ -2168,27 +2206,13 @@ impl ScalarValue {
         let mut scalars = Vec::with_capacity(array.len());
 
         for index in 0..array.len() {
-            let scalar_values = match array.data_type() {
-                DataType::List(_) => {
-                    let list_array = as_list_array(array);
-                    match list_array.is_null(index) {
-                        true => Vec::new(),
-                        false => {
-                            let nested_array = list_array.value(index);
-                            
ScalarValue::convert_array_to_scalar_vec(&nested_array)?
-                                .into_iter()
-                                .flatten()
-                                .collect()
-                        }
-                    }
-                }
-                _ => {
-                    let scalar = ScalarValue::try_from_array(array, index)?;
-                    vec![scalar]
-                }
-            };
+            let nested_array = array.as_list::<i32>().value(index);
+            let scalar_values = (0..nested_array.len())
+                .map(|i| ScalarValue::try_from_array(&nested_array, i))
+                .collect::<Result<Vec<_>>>()?;
             scalars.push(scalar_values);
         }
+
         Ok(scalars)
     }
 
diff --git a/datafusion/core/tests/sql/aggregates.rs 
b/datafusion/core/tests/sql/aggregates.rs
index af6d0d5f4e..84b791a3de 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -44,9 +44,9 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
     // We should have 1 row containing a list
     let column = actual[0].column(0);
     assert_eq!(column.len(), 1);
-
     let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&column)?;
     let mut scalars = scalar_vec[0].clone();
+
     // workaround lack of Ord of ScalarValue
     let cmp = |a: &ScalarValue, b: &ScalarValue| {
         a.partial_cmp(b).expect("Can compare ScalarValues")
diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs 
b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
index 2e9df477d5..b073b00578 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
@@ -24,6 +24,7 @@ use std::sync::Arc;
 
 use arrow::array::ArrayRef;
 use arrow::datatypes::{DataType, Field};
+use arrow_array::cast::AsArray;
 
 use crate::aggregate::utils::down_cast_any_ref;
 use crate::expressions::format_state_name;
@@ -138,9 +139,10 @@ impl Accumulator for DistinctArrayAggAccumulator {
         assert_eq!(values.len(), 1, "batch input should only include 1 
column!");
 
         let array = &values[0];
-        let scalar_vec = ScalarValue::convert_array_to_scalar_vec(array)?;
-        for scalars in scalar_vec {
-            self.values.extend(scalars);
+
+        for i in 0..array.len() {
+            let scalar = ScalarValue::try_from_array(&array, i)?;
+            self.values.insert(scalar);
         }
 
         Ok(())
@@ -151,7 +153,12 @@ impl Accumulator for DistinctArrayAggAccumulator {
             return Ok(());
         }
 
-        self.update_batch(states)
+        let array = &states[0];
+
+        assert_eq!(array.len(), 1, "state array should only include 1 row!");
+        // Unwrap outer ListArray then do update batch
+        let inner_array = array.as_list::<i32>().value(0);
+        self.update_batch(&[inner_array])
     }
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
@@ -181,47 +188,55 @@ mod tests {
     use arrow_array::Array;
     use arrow_array::ListArray;
     use arrow_buffer::OffsetBuffer;
-    use datafusion_common::utils::array_into_list_array;
     use datafusion_common::{internal_err, DataFusionError};
 
-    // arrow::compute::sort cann't sort ListArray directly, so we need to sort 
the inner primitive array and wrap it back into ListArray.
-    fn sort_list_inner(arr: ScalarValue) -> ScalarValue {
-        let arr = match arr {
-            ScalarValue::List(arr) => arr.value(0),
-            _ => {
-                panic!("Expected ScalarValue::List, got {:?}", arr)
-            }
-        };
+    // arrow::compute::sort can't sort nested ListArray directly, so we 
compare the scalar values pair-wise.
+    fn compare_list_contents(
+        expected: Vec<ScalarValue>,
+        actual: ScalarValue,
+    ) -> Result<()> {
+        let array = actual.to_array()?;
+        let list_array = array.as_list::<i32>();
+        let inner_array = list_array.value(0);
+        let mut actual_scalars = vec![];
+        for index in 0..inner_array.len() {
+            let sv = ScalarValue::try_from_array(&inner_array, index)?;
+            actual_scalars.push(sv);
+        }
 
-        let arr = arrow::compute::sort(&arr, None).unwrap();
-        let list_arr = array_into_list_array(arr);
-        ScalarValue::List(Arc::new(list_arr))
-    }
+        if actual_scalars.len() != expected.len() {
+            return internal_err!(
+                "Expected and actual list lengths differ: expected={}, 
actual={}",
+                expected.len(),
+                actual_scalars.len()
+            );
+        }
 
-    fn compare_list_contents(expected: ScalarValue, actual: ScalarValue) -> 
Result<()> {
-        let actual = sort_list_inner(actual);
-
-        match (&expected, &actual) {
-            (ScalarValue::List(arr1), ScalarValue::List(arr2)) => {
-                if arr1.eq(arr2) {
-                    Ok(())
-                } else {
-                    internal_err!(
-                        "Actual value {:?} not found in expected values {:?}",
-                        actual,
-                        expected
-                    )
+        let mut seen = vec![false; expected.len()];
+        for v in expected {
+            let mut found = false;
+            for (i, sv) in actual_scalars.iter().enumerate() {
+                if sv == &v {
+                    seen[i] = true;
+                    found = true;
+                    break;
                 }
             }
-            _ => {
-                internal_err!("Expected scalar lists as inputs")
+            if !found {
+                return internal_err!(
+                    "Expected value {:?} not found in actual values {:?}",
+                    v,
+                    actual_scalars
+                );
             }
         }
+
+        Ok(())
     }
 
     fn check_distinct_array_agg(
         input: ArrayRef,
-        expected: ScalarValue,
+        expected: Vec<ScalarValue>,
         datatype: DataType,
     ) -> Result<()> {
         let schema = Schema::new(vec![Field::new("a", datatype.clone(), 
false)]);
@@ -234,14 +249,13 @@ mod tests {
             true,
         ));
         let actual = aggregate(&batch, agg)?;
-
         compare_list_contents(expected, actual)
     }
 
     fn check_merge_distinct_array_agg(
         input1: ArrayRef,
         input2: ArrayRef,
-        expected: ScalarValue,
+        expected: Vec<ScalarValue>,
         datatype: DataType,
     ) -> Result<()> {
         let schema = Schema::new(vec![Field::new("a", datatype.clone(), 
false)]);
@@ -262,23 +276,20 @@ mod tests {
         accum1.merge_batch(&[array])?;
 
         let actual = accum1.evaluate()?;
-
         compare_list_contents(expected, actual)
     }
 
     #[test]
     fn distinct_array_agg_i32() -> Result<()> {
         let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2]));
-        let expected =
-            ScalarValue::List(Arc::new(
-                ListArray::from_iter_primitive::<Int32Type, _, 
_>(vec![Some(vec![
-                    Some(1),
-                    Some(2),
-                    Some(4),
-                    Some(5),
-                    Some(7),
-                ])]),
-            ));
+
+        let expected = vec![
+            ScalarValue::Int32(Some(1)),
+            ScalarValue::Int32(Some(2)),
+            ScalarValue::Int32(Some(4)),
+            ScalarValue::Int32(Some(5)),
+            ScalarValue::Int32(Some(7)),
+        ];
 
         check_distinct_array_agg(col, expected, DataType::Int32)
     }
@@ -288,18 +299,15 @@ mod tests {
         let col1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 
2]));
         let col2: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 7, 8, 4]));
 
-        let expected =
-            ScalarValue::List(Arc::new(
-                ListArray::from_iter_primitive::<Int32Type, _, 
_>(vec![Some(vec![
-                    Some(1),
-                    Some(2),
-                    Some(3),
-                    Some(4),
-                    Some(5),
-                    Some(7),
-                    Some(8),
-                ])]),
-            ));
+        let expected = vec![
+            ScalarValue::Int32(Some(1)),
+            ScalarValue::Int32(Some(2)),
+            ScalarValue::Int32(Some(3)),
+            ScalarValue::Int32(Some(4)),
+            ScalarValue::Int32(Some(5)),
+            ScalarValue::Int32(Some(7)),
+            ScalarValue::Int32(Some(8)),
+        ];
 
         check_merge_distinct_array_agg(col1, col2, expected, DataType::Int32)
     }
@@ -351,23 +359,16 @@ mod tests {
         let l2 = ScalarValue::List(Arc::new(l2));
         let l3 = ScalarValue::List(Arc::new(l3));
 
-        // Duplicate l1 in the input array and check that it is deduped in the 
output.
-        let array = ScalarValue::iter_to_array(vec![l1.clone(), l2, l3, 
l1]).unwrap();
-
-        let expected =
-            ScalarValue::List(Arc::new(
-                ListArray::from_iter_primitive::<Int32Type, _, 
_>(vec![Some(vec![
-                    Some(1),
-                    Some(2),
-                    Some(3),
-                    Some(4),
-                    Some(5),
-                    Some(6),
-                    Some(7),
-                    Some(8),
-                    Some(9),
-                ])]),
-            ));
+        // Duplicate l1 and l3 in the input array and check that they are deduped 
in the output.
+        let array = ScalarValue::iter_to_array(vec![
+            l1.clone(),
+            l2.clone(),
+            l3.clone(),
+            l3.clone(),
+            l1.clone(),
+        ])
+        .unwrap();
+        let expected = vec![l1, l2, l3];
 
         check_distinct_array_agg(
             array,
@@ -426,22 +427,10 @@ mod tests {
         let l3 = ScalarValue::List(Arc::new(l3));
 
         // Duplicate l1 in the input array and check that it is deduped in the 
output.
-        let input1 = ScalarValue::iter_to_array(vec![l1.clone(), l2]).unwrap();
-        let input2 = ScalarValue::iter_to_array(vec![l1, l3]).unwrap();
-
-        let expected =
-            ScalarValue::List(Arc::new(
-                ListArray::from_iter_primitive::<Int32Type, _, 
_>(vec![Some(vec![
-                    Some(1),
-                    Some(2),
-                    Some(3),
-                    Some(4),
-                    Some(5),
-                    Some(6),
-                    Some(7),
-                    Some(8),
-                ])]),
-            ));
+        let input1 = ScalarValue::iter_to_array(vec![l1.clone(), 
l2.clone()]).unwrap();
+        let input2 = ScalarValue::iter_to_array(vec![l1.clone(), 
l3.clone()]).unwrap();
+
+        let expected = vec![l1, l2, l3];
 
         check_merge_distinct_array_agg(input1, input2, expected, 
DataType::Int32)
     }
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs 
b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
index 52afd82d03..71782fcc5f 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
@@ -26,6 +26,7 @@ use std::sync::Arc;
 use ahash::RandomState;
 use arrow::array::{Array, ArrayRef};
 use arrow::datatypes::{DataType, Field, TimeUnit};
+use arrow_array::cast::AsArray;
 use arrow_array::types::{
     Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float16Type, 
Float32Type,
     Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, 
Time32MillisecondType,
@@ -250,11 +251,10 @@ impl Accumulator for DistinctCountAccumulator {
             return Ok(());
         }
         assert_eq!(states.len(), 1, "array_agg states must be singleton!");
-        let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
-        for scalars in scalar_vec.into_iter() {
-            self.values.extend(scalars);
-        }
-        Ok(())
+        let array = &states[0];
+        let list_array = array.as_list::<i32>();
+        let inner_array = list_array.value(0);
+        self.update_batch(&[inner_array])
     }
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt 
b/datafusion/sqllogictest/test_files/aggregate.slt
index fdd70a80ac..109c64f060 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -139,6 +139,61 @@ AggregateExec: mode=Final, gby=[], 
aggr=[ARRAY_AGG(agg_order.c1)]
 --------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
 ----------CsvExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_agg_multi_order.csv]]}, 
projection=[c1, c2, c3], has_header=true
 
+# test array_agg_order with list data type
+statement ok
+CREATE TABLE array_agg_order_list_table AS VALUES
+  ('w', 2, [1,2,3], 10),
+  ('w', 1, [9,5,2], 20),
+  ('w', 1, [3,2,5], 30),
+  ('b', 2, [4,5,6], 20),
+  ('b', 1, [7,8,9], 30)
+;
+
+query T? rowsort
+select column1, array_agg(column3 order by column2, column4 desc) from 
array_agg_order_list_table group by column1;
+----
+b [[7, 8, 9], [4, 5, 6]]
+w [[3, 2, 5], [9, 5, 2], [1, 2, 3]]
+
+query T?? rowsort
+select column1, first_value(column3 order by column2, column4 desc), 
last_value(column3 order by column2, column4 desc) from 
array_agg_order_list_table group by column1;
+----
+b [7, 8, 9] [4, 5, 6]
+w [3, 2, 5] [1, 2, 3]
+
+query T? rowsort
+select column1, nth_value(column3, 2 order by column2, column4 desc) from 
array_agg_order_list_table group by column1;
+----
+b [4, 5, 6]
+w [9, 5, 2]
+
+statement ok
+drop table array_agg_order_list_table;
+
+# test array_agg_distinct with list data type
+statement ok
+CREATE TABLE array_agg_distinct_list_table AS VALUES
+  ('w', [0,1]),
+  ('w', [0,1]),
+  ('w', [1,0]),
+  ('b', [1,0]),
+  ('b', [1,0]),
+  ('b', [1,0]),
+  ('b', [0,1])
+;
+
+# Apply array_sort to have deterministic result, higher dimension nested 
array also works but not for array sort,
+# so they are covered in 
`datafusion/physical-expr/src/aggregate/array_agg_distinct.rs`
+query ??
+select array_sort(c1), array_sort(c2) from (
+  select array_agg(distinct column1) as c1, array_agg(distinct column2) as c2 
from array_agg_distinct_list_table
+);
+----
+[b, w] [[0, 1], [1, 0]]
+
+statement ok
+drop table array_agg_distinct_list_table;
+
 statement error This feature is not implemented: LIMIT not supported in 
ARRAY_AGG: 1
 SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100
 
@@ -3259,4 +3314,4 @@ SELECT 0 AS "t.a" FROM t HAVING MAX(t.a) = 0;
 ----
 
 statement ok
-DROP TABLE t;
\ No newline at end of file
+DROP TABLE t;
diff --git a/datafusion/sqllogictest/test_files/aggregates_topk.slt 
b/datafusion/sqllogictest/test_files/aggregates_topk.slt
index bd8f00e041..3f139ede8c 100644
--- a/datafusion/sqllogictest/test_files/aggregates_topk.slt
+++ b/datafusion/sqllogictest/test_files/aggregates_topk.slt
@@ -212,3 +212,6 @@ b 0 -2
 a -1 -1
 NULL 0 0
 c 1 2
+
+statement ok
+drop table traces;

Reply via email to