This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 195f8256ce feat: support `LargeList` in `flatten` (#9110)
195f8256ce is described below

commit 195f8256ce5414844c297601f145757f87532edb
Author: Alex Huang <[email protected]>
AuthorDate: Mon Feb 5 20:25:03 2024 +0800

    feat: support `LargeList` in `flatten` (#9110)
    
    * support FixedSizeList in flatten
    
    * Refactor flatten function and add test cases
    
    * remove redundant tests
---
 datafusion/expr/src/built_in_function.rs          |  9 ++--
 datafusion/physical-expr/src/array_expressions.rs | 52 ++++++++++++++++-------
 datafusion/sqllogictest/test_files/array.slt      | 42 ++++++++++++++++--
 3 files changed, 78 insertions(+), 25 deletions(-)

diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs
index b1b74c1628..4cdf0c4a11 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -547,11 +547,10 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::Flatten => {
                 fn get_base_type(data_type: &DataType) -> Result<DataType> {
                     match data_type {
-                        DataType::List(field) => match field.data_type() {
-                            DataType::List(_) => get_base_type(field.data_type()),
-                            _ => Ok(data_type.to_owned()),
-                        },
-                        _ => internal_err!("Not reachable, data_type should be List"),
+                        DataType::List(field) if matches!(field.data_type(), DataType::List(_)) => get_base_type(field.data_type()),
+                        DataType::LargeList(field) if matches!(field.data_type(), DataType::LargeList(_)) => get_base_type(field.data_type()),
+                        DataType::Null | DataType::List(_) | DataType::LargeList(_) => Ok(data_type.to_owned()),
+                        _ => internal_err!("Not reachable, data_type should be List or LargeList"),
                     }
                 }
 
diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs
index 844dae0917..0709e66a35 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -2246,38 +2246,41 @@ fn generic_list_cardinality<O: OffsetSizeTrait>(
 }
 
 // Create new offsets that are equivalent to `flatten` the array.
-fn get_offsets_for_flatten(
-    offsets: OffsetBuffer<i32>,
-    indexes: OffsetBuffer<i32>,
-) -> OffsetBuffer<i32> {
+fn get_offsets_for_flatten<O: OffsetSizeTrait>(
+    offsets: OffsetBuffer<O>,
+    indexes: OffsetBuffer<O>,
+) -> OffsetBuffer<O> {
     let buffer = offsets.into_inner();
-    let offsets: Vec<i32> = indexes.iter().map(|i| buffer[*i as usize]).collect();
+    let offsets: Vec<O> = indexes
+        .iter()
+        .map(|i| buffer[i.to_usize().unwrap()])
+        .collect();
     OffsetBuffer::new(offsets.into())
 }
 
-fn flatten_internal(
-    array: &dyn Array,
-    indexes: Option<OffsetBuffer<i32>>,
-) -> Result<ListArray> {
-    let list_arr = as_list_array(array)?;
+fn flatten_internal<O: OffsetSizeTrait>(
+    list_arr: GenericListArray<O>,
+    indexes: Option<OffsetBuffer<O>>,
+) -> Result<GenericListArray<O>> {
     let (field, offsets, values, _) = list_arr.clone().into_parts();
     let data_type = field.data_type();
 
     match data_type {
         // Recursively get the base offsets for flattened array
-        DataType::List(_) => {
+        DataType::List(_) | DataType::LargeList(_) => {
+            let sub_list = as_generic_list_array::<O>(&values)?;
             if let Some(indexes) = indexes {
                 let offsets = get_offsets_for_flatten(offsets, indexes);
-                flatten_internal(&values, Some(offsets))
+                flatten_internal::<O>(sub_list.clone(), Some(offsets))
             } else {
-                flatten_internal(&values, Some(offsets))
+                flatten_internal::<O>(sub_list.clone(), Some(offsets))
             }
         }
         // Reach the base level, create a new list array
         _ => {
             if let Some(indexes) = indexes {
                 let offsets = get_offsets_for_flatten(offsets, indexes);
-                let list_arr = ListArray::new(field, offsets, values, None);
+                let list_arr = GenericListArray::<O>::new(field, offsets, values, None);
                 Ok(list_arr)
             } else {
                 Ok(list_arr.clone())
@@ -2292,8 +2295,25 @@ pub fn flatten(args: &[ArrayRef]) -> Result<ArrayRef> {
         return exec_err!("flatten expects one argument");
     }
 
-    let flattened_array = flatten_internal(&args[0], None)?;
-    Ok(Arc::new(flattened_array) as ArrayRef)
+    let array_type = args[0].data_type();
+    match array_type {
+        DataType::List(_) => {
+            let list_arr = as_list_array(&args[0])?;
+            let flattened_array = flatten_internal::<i32>(list_arr.clone(), None)?;
+            Ok(Arc::new(flattened_array) as ArrayRef)
+        }
+        DataType::LargeList(_) => {
+            let list_arr = as_large_list_array(&args[0])?;
+            let flattened_array = flatten_internal::<i64>(list_arr.clone(), None)?;
+            Ok(Arc::new(flattened_array) as ArrayRef)
+        }
+        DataType::Null => Ok(args[0].clone()),
+        _ => {
+            exec_err!("flatten does not support type '{array_type:?}'")
+        }
+    }
+
+    // Ok(Arc::new(flattened_array) as ArrayRef)
 }
 
 /// Dispatch array length computation based on the offset type.
diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt
index 4fdc428d7a..36a656eb7f 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -202,6 +202,17 @@ AS VALUES
   (make_array([1, 2], [3, 4], [5, 6]), make_array([[8]]), make_array([[[1,2]]], [[[3]]]), make_array([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]))
 ;
 
+statement ok
+CREATE TABLE large_flatten_table
+AS
+  SELECT
+    arrow_cast(column1, 'LargeList(LargeList(Int64))') AS column1,
+    arrow_cast(column2, 'LargeList(LargeList(LargeList(Int64)))') AS column2,
+    arrow_cast(column3, 'LargeList(LargeList(LargeList(LargeList(Int64))))') AS column3,
+    arrow_cast(column4, 'LargeList(LargeList(Float64))') AS column4
+  FROM flatten_table
+;
+
 statement ok
 CREATE TABLE array_has_table_1D
 AS VALUES
@@ -5345,6 +5356,13 @@ select array_concat(column1, [7]) from arrays_values_v2;
 [7]
 
 # flatten
+# follow DuckDB
+query ?
+select flatten(NULL);
+----
+NULL
+
+# flatten with scalar values #1
 query ???
 select flatten(make_array(1, 2, 1, 3, 2)),
        flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))),
@@ -5352,12 +5370,14 @@ select flatten(make_array(1, 2, 1, 3, 2)),
 ----
 [1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4]
 
-query ????
-select column1, column2, column3, column4 from flatten_table;
+query ???
+select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'LargeList(Int64)')),
+       flatten(arrow_cast(make_array([1], [2, 3], [null], make_array(4, null, 5)), 'LargeList(LargeList(Int64))')),
+       flatten(arrow_cast(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]), 'LargeList(LargeList(LargeList(Float64)))'));
 ----
-[[1], [2], [3]] [[[1, 2, 3]], [[4, 5]], [[6]]] [[[[1]]], [[[2, 3]]]] [[1.0], [2.1, 2.2], [3.2, 3.3, 3.4]]
-[[1, 2], [3, 4], [5, 6]] [[[8]]] [[[[1, 2]]], [[[3]]]] [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
+[1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4]
 
+# flatten with column values
 query ????
 select flatten(column1),
        flatten(column2),
@@ -5368,6 +5388,17 @@ from flatten_table;
 [1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
 [1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
 
+query ????
+select flatten(column1),
+       flatten(column2),
+       flatten(column3),
+       flatten(column4)
+from large_flatten_table;
+----
+[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
+[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
+
+## empty
 # empty scalar function #1
 query B
 select empty(make_array(1));
@@ -5746,6 +5777,9 @@ drop table fixed_size_nested_arrays_with_repeating_elements;
 statement ok
 drop table flatten_table;
 
+statement ok
+drop table large_flatten_table;
+
 statement ok
 drop table arrays_values_without_nulls;
 

Reply via email to