This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new a0b0221d2e Change `flatten` so it does only a level, not recursively (#15160)
a0b0221d2e is described below

commit a0b0221d2e3b29e7e83492bcb09e307e14a062f8
Author: delamarch3 <[email protected]>
AuthorDate: Thu Apr 17 17:54:02 2025 +0100

    Change `flatten` so it does only a level, not recursively (#15160)
    
    * flatten array in a single step instead of recursive
    
    * clippy
    
    * update flatten type signature to Array
    
    * add fixed list to list coercion to flatten signature
    
    * support LargeList(List) and LargeList(FixedSizeList) in flatten
    
    * add test for LargeList(FixedSizeList)
    
    * handle nulls
    
    * uncomment flatten(NULL) test - it already works
---
 datafusion/functions-nested/src/flatten.rs   | 153 ++++++++++++++++-----------
 datafusion/sqllogictest/test_files/array.slt |  39 ++++---
 2 files changed, 115 insertions(+), 77 deletions(-)

diff --git a/datafusion/functions-nested/src/flatten.rs 
b/datafusion/functions-nested/src/flatten.rs
index f288035948..4279f04e3d 100644
--- a/datafusion/functions-nested/src/flatten.rs
+++ b/datafusion/functions-nested/src/flatten.rs
@@ -18,19 +18,18 @@
 //! [`ScalarUDFImpl`] definitions for flatten function.
 
 use crate::utils::make_scalar_function;
-use arrow::array::{ArrayRef, GenericListArray, OffsetSizeTrait};
+use arrow::array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait};
 use arrow::buffer::OffsetBuffer;
 use arrow::datatypes::{
     DataType,
     DataType::{FixedSizeList, LargeList, List, Null},
 };
-use datafusion_common::cast::{
-    as_generic_list_array, as_large_list_array, as_list_array,
-};
+use datafusion_common::cast::{as_large_list_array, as_list_array};
+use datafusion_common::utils::ListCoercion;
 use datafusion_common::{exec_err, utils::take_function_args, Result};
 use datafusion_expr::{
-    ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, 
Signature,
-    TypeSignature, Volatility,
+    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, 
Documentation,
+    ScalarUDFImpl, Signature, TypeSignature, Volatility,
 };
 use datafusion_macros::user_doc;
 use std::any::Any;
@@ -77,9 +76,11 @@ impl Flatten {
     pub fn new() -> Self {
         Self {
             signature: Signature {
-                // TODO (https://github.com/apache/datafusion/issues/13757) 
flatten should be single-step, not recursive
                 type_signature: TypeSignature::ArraySignature(
-                    ArrayFunctionSignature::RecursiveArray,
+                    ArrayFunctionSignature::Array {
+                        arguments: vec![ArrayFunctionArgument::Array],
+                        array_coercion: 
Some(ListCoercion::FixedSizedListToList),
+                    },
                 ),
                 volatility: Volatility::Immutable,
             },
@@ -102,25 +103,23 @@ impl ScalarUDFImpl for Flatten {
     }
 
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        fn get_base_type(data_type: &DataType) -> Result<DataType> {
-            match data_type {
-                List(field) | FixedSizeList(field, _)
-                    if matches!(field.data_type(), List(_) | FixedSizeList(_, 
_)) =>
-                {
-                    get_base_type(field.data_type())
-                }
-                LargeList(field) if matches!(field.data_type(), LargeList(_)) 
=> {
-                    get_base_type(field.data_type())
+        let data_type = match &arg_types[0] {
+            List(field) | FixedSizeList(field, _) => match field.data_type() {
+                List(field) | FixedSizeList(field, _) => 
List(Arc::clone(field)),
+                _ => arg_types[0].clone(),
+            },
+            LargeList(field) => match field.data_type() {
+                List(field) | LargeList(field) | FixedSizeList(field, _) => {
+                    LargeList(Arc::clone(field))
                 }
-                Null | List(_) | LargeList(_) => Ok(data_type.to_owned()),
-                FixedSizeList(field, _) => Ok(List(Arc::clone(field))),
-                _ => exec_err!(
-                    "Not reachable, data_type should be List, LargeList or 
FixedSizeList"
-                ),
-            }
-        }
+                _ => arg_types[0].clone(),
+            },
+            Null => Null,
+            _ => exec_err!(
+                "Not reachable, data_type should be List, LargeList or 
FixedSizeList"
+            )?,
+        };
 
-        let data_type = get_base_type(&arg_types[0])?;
         Ok(data_type)
     }
 
@@ -146,14 +145,62 @@ pub fn flatten_inner(args: &[ArrayRef]) -> 
Result<ArrayRef> {
 
     match array.data_type() {
         List(_) => {
-            let list_arr = as_list_array(&array)?;
-            let flattened_array = flatten_internal::<i32>(list_arr.clone(), 
None)?;
-            Ok(Arc::new(flattened_array) as ArrayRef)
+            let (field, offsets, values, nulls) =
+                as_list_array(&array)?.clone().into_parts();
+
+            match field.data_type() {
+                List(_) => {
+                    let (inner_field, inner_offsets, inner_values, _) =
+                        as_list_array(&values)?.clone().into_parts();
+                    let offsets = 
get_offsets_for_flatten::<i32>(inner_offsets, offsets);
+                    let flattened_array = GenericListArray::<i32>::new(
+                        inner_field,
+                        offsets,
+                        inner_values,
+                        nulls,
+                    );
+
+                    Ok(Arc::new(flattened_array) as ArrayRef)
+                }
+                LargeList(_) => {
+                    exec_err!("flatten does not support type '{:?}'", 
array.data_type())?
+                }
+                _ => Ok(Arc::clone(array) as ArrayRef),
+            }
         }
         LargeList(_) => {
-            let list_arr = as_large_list_array(&array)?;
-            let flattened_array = flatten_internal::<i64>(list_arr.clone(), 
None)?;
-            Ok(Arc::new(flattened_array) as ArrayRef)
+            let (field, offsets, values, nulls) =
+                as_large_list_array(&array)?.clone().into_parts();
+
+            match field.data_type() {
+                List(_) => {
+                    let (inner_field, inner_offsets, inner_values, _) =
+                        as_list_array(&values)?.clone().into_parts();
+                    let offsets = get_large_offsets_for_flatten(inner_offsets, 
offsets);
+                    let flattened_array = GenericListArray::<i64>::new(
+                        inner_field,
+                        offsets,
+                        inner_values,
+                        nulls,
+                    );
+
+                    Ok(Arc::new(flattened_array) as ArrayRef)
+                }
+                LargeList(_) => {
+                    let (inner_field, inner_offsets, inner_values, nulls) =
+                        as_large_list_array(&values)?.clone().into_parts();
+                    let offsets = 
get_offsets_for_flatten::<i64>(inner_offsets, offsets);
+                    let flattened_array = GenericListArray::<i64>::new(
+                        inner_field,
+                        offsets,
+                        inner_values,
+                        nulls,
+                    );
+
+                    Ok(Arc::new(flattened_array) as ArrayRef)
+                }
+                _ => Ok(Arc::clone(array) as ArrayRef),
+            }
         }
         Null => Ok(Arc::clone(array)),
         _ => {
@@ -162,37 +209,6 @@ pub fn flatten_inner(args: &[ArrayRef]) -> 
Result<ArrayRef> {
     }
 }
 
-fn flatten_internal<O: OffsetSizeTrait>(
-    list_arr: GenericListArray<O>,
-    indexes: Option<OffsetBuffer<O>>,
-) -> Result<GenericListArray<O>> {
-    let (field, offsets, values, _) = list_arr.clone().into_parts();
-    let data_type = field.data_type();
-
-    match data_type {
-        // Recursively get the base offsets for flattened array
-        List(_) | LargeList(_) => {
-            let sub_list = as_generic_list_array::<O>(&values)?;
-            if let Some(indexes) = indexes {
-                let offsets = get_offsets_for_flatten(offsets, indexes);
-                flatten_internal::<O>(sub_list.clone(), Some(offsets))
-            } else {
-                flatten_internal::<O>(sub_list.clone(), Some(offsets))
-            }
-        }
-        // Reach the base level, create a new list array
-        _ => {
-            if let Some(indexes) = indexes {
-                let offsets = get_offsets_for_flatten(offsets, indexes);
-                let list_arr = GenericListArray::<O>::new(field, offsets, 
values, None);
-                Ok(list_arr)
-            } else {
-                Ok(list_arr)
-            }
-        }
-    }
-}
-
 // Create new offsets that are equivalent to `flatten` the array.
 fn get_offsets_for_flatten<O: OffsetSizeTrait>(
     offsets: OffsetBuffer<O>,
@@ -205,3 +221,16 @@ fn get_offsets_for_flatten<O: OffsetSizeTrait>(
         .collect();
     OffsetBuffer::new(offsets.into())
 }
+
+// Create new large offsets that are equivalent to `flatten` the array.
+fn get_large_offsets_for_flatten<O: OffsetSizeTrait, P: OffsetSizeTrait>(
+    offsets: OffsetBuffer<O>,
+    indexes: OffsetBuffer<P>,
+) -> OffsetBuffer<i64> {
+    let buffer = offsets.into_inner();
+    let offsets: Vec<i64> = indexes
+        .iter()
+        .map(|i| buffer[i.to_usize().unwrap()].to_i64().unwrap())
+        .collect();
+    OffsetBuffer::new(offsets.into())
+}
diff --git a/datafusion/sqllogictest/test_files/array.slt 
b/datafusion/sqllogictest/test_files/array.slt
index e780d6c8b2..9772de3db3 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -7285,12 +7285,10 @@ select array_concat(column1, [7]) from arrays_values_v2;
 
 # flatten
 
-#TODO: https://github.com/apache/datafusion/issues/7142
-# follow DuckDB
-#query ?
-#select flatten(NULL);
-#----
-#NULL
+query ?
+select flatten(NULL);
+----
+NULL
 
 # flatten with scalar values #1
 query ???
@@ -7298,21 +7296,21 @@ select flatten(make_array(1, 2, 1, 3, 2)),
        flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))),
        flatten(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]));
 ----
-[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4]
+[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]]
 
 query ???
 select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'LargeList(Int64)')),
        flatten(arrow_cast(make_array([1], [2, 3], [null], make_array(4, null, 
5)), 'LargeList(LargeList(Int64))')),
        flatten(arrow_cast(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]), 
'LargeList(LargeList(LargeList(Float64)))'));
 ----
-[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4]
+[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]]
 
 query ???
 select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'FixedSizeList(5, 
Int64)')),
        flatten(arrow_cast(make_array([1], [2, 3], [null], make_array(4, null, 
5)), 'FixedSizeList(4, List(Int64))')),
        flatten(arrow_cast(make_array([[1.1], [2.2]], [[3.3], [4.4]]), 
'FixedSizeList(2, List(List(Float64)))'));
 ----
-[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4]
+[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]]
 
 # flatten with column values
 query ????
@@ -7322,8 +7320,8 @@ select flatten(column1),
        flatten(column4)
 from flatten_table;
 ----
-[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
-[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
+[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 
3.4]
+[1, 2, 3, 4, 5, 6] [[8]] [[[1, 2]], [[3]]] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
 
 query ????
 select flatten(column1),
@@ -7332,8 +7330,8 @@ select flatten(column1),
        flatten(column4)
 from large_flatten_table;
 ----
-[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
-[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
+[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 
3.4]
+[1, 2, 3, 4, 5, 6] [[8]] [[[1, 2]], [[3]]] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
 
 query ????
 select flatten(column1),
@@ -7342,8 +7340,19 @@ select flatten(column1),
        flatten(column4)
 from fixed_size_flatten_table;
 ----
-[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
-[1, 2, 3, 4, 5, 6] [8, 9, 10, 11, 12, 13] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 
6.0]
+[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 
3.4]
+[1, 2, 3, 4, 5, 6] [[8], [9, 10], [11, 12, 13]] [[[1, 2]], [[3]]] [1.0, 2.0, 
3.0, 4.0, 5.0, 6.0]
+
+# flatten with different inner list type
+query ??????
+select flatten(arrow_cast(make_array([1, 2], [3, 4]), 'List(FixedSizeList(2, 
Int64))')),
+       flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 
'List(FixedSizeList(1, List(Int64)))')),
+       flatten(arrow_cast(make_array([1, 2], [3, 4]), 
'LargeList(List(Int64))')),
+       flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 
'LargeList(List(List(Int64)))')),
+       flatten(arrow_cast(make_array([1, 2], [3, 4]), 
'LargeList(FixedSizeList(2, Int64))')),
+       flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 
'LargeList(FixedSizeList(1, List(Int64)))'))
+----
+[1, 2, 3, 4] [[1, 2], [3, 4]] [1, 2, 3, 4] [[1, 2], [3, 4]] [1, 2, 3, 4] [[1, 
2], [3, 4]]
 
 ## empty (aliases: `array_empty`, `list_empty`)
 # empty scalar function #1


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to