Jefffrey commented on code in PR #18137:
URL: https://github.com/apache/datafusion/pull/18137#discussion_r2442730061


##########
datafusion/functions/src/string/concat.rs:
##########
@@ -90,23 +191,124 @@ impl ScalarUDFImpl for ConcatFunc {
 
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
         use DataType::*;
-        let mut dt = &Utf8;
-        arg_types.iter().for_each(|data_type| {
-            if data_type == &Utf8View {
-                dt = data_type;
+        
+        // Check if any argument is an array type
+        let has_array = arg_types.iter().any(|dt| {
+            matches!(dt, List(_) | LargeList(_) | FixedSizeList(_, _))
+        });
+        
+        if has_array {
+            // Use array_concat-style return type logic
+            use datafusion_common::utils::{base_type, list_ndims};
+            use datafusion_expr::binary::type_union_resolution;
+            
+            let mut max_dims = 0;
+            let mut large_list = false;
+            let mut element_types = Vec::with_capacity(arg_types.len());
+            
+            for arg_type in arg_types {
+                match arg_type {
+                    Null | List(_) | FixedSizeList(..) => (),
+                    LargeList(_) => large_list = true,
+                    _ => {
+                        return plan_err!("concat does not support type 
{arg_type}")
+                    }
+                }
+
+                max_dims = max_dims.max(list_ndims(arg_type));
+                element_types.push(base_type(arg_type))
             }
-            if data_type == &LargeUtf8 && dt != &Utf8View {
-                dt = data_type;
+
+            if max_dims == 0 {
+                Ok(Null)
+            } else {
+                // Handle the case where we have empty arrays (Null element 
types)
+                // Filter out Null types and find the first non-null type
+                let non_null_types: Vec<_> = element_types.iter().filter(|t| 
**t != Null).collect();
+                
+                let unified_element_type = if non_null_types.is_empty() {
+                    // All arrays are empty, use Int32 as default
+                    Int32
+                } else if non_null_types.len() == 1 {
+                    // Only one non-null type
+                    non_null_types[0].clone()
+                } else {
+                    // Multiple non-null types, try to unify them
+                    if let Some(unified) = 
type_union_resolution(&non_null_types.into_iter().cloned().collect::<Vec<_>>()) 
{
+                        unified
+                    } else {
+                        return plan_err!(
+                            "Failed to unify argument types of concat: [{}]",
+                            arg_types.iter().map(|t| 
format!("{t}")).collect::<Vec<_>>().join(", ")
+                        );
+                    }
+                };
+                
+                // Build the return type
+                let mut return_type = unified_element_type;
+                for _ in 1..max_dims {
+                    return_type = DataType::new_list(return_type, true)
+                }
+
+                if large_list {
+                    Ok(DataType::new_large_list(return_type, true))
+                } else {
+                    Ok(DataType::new_list(return_type, true))
+                }
             }
-        });
+        } else {
+            // Original string concatenation logic
+            let mut dt = &Utf8;
+            arg_types.iter().for_each(|data_type| {
+                if data_type == &Utf8View {
+                    dt = data_type;
+                }
+                if data_type == &LargeUtf8 && dt != &Utf8View {
+                    dt = data_type;
+                }
+            });
 
-        Ok(dt.to_owned())
+            Ok(dt.to_owned())
+        }
     }
 
     /// Concatenates the text representations of all the arguments. NULL 
arguments are ignored.
     /// concat('abcde', 2, NULL, 22) = 'abcde222'
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
         let ScalarFunctionArgs { args, .. } = args;
+        
+        // Check if any argument is an array type
+        let has_array = args.iter().any(|arg| {
+            matches!(
+                arg.data_type(),
+                DataType::List(_) | DataType::LargeList(_) | 
DataType::FixedSizeList(_, _)
+            )
+        });
+        
+        if has_array {
+            // Determine the number of rows for array operations
+            let num_rows = args
+                .iter()
+                .filter_map(|arg| match arg {
+                    ColumnarValue::Array(array) => Some(array.len()),
+                    _ => None,
+                })
+                .next()
+                .unwrap_or(1);

Review Comment:
   This should be available via the `number_rows` field of `ScalarFunctionArgs`, rather than being inferred from the length of the first array argument.



##########
datafusion/functions/src/string/concat.rs:
##########
@@ -90,23 +191,124 @@ impl ScalarUDFImpl for ConcatFunc {
 
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
         use DataType::*;
-        let mut dt = &Utf8;
-        arg_types.iter().for_each(|data_type| {
-            if data_type == &Utf8View {
-                dt = data_type;
+        
+        // Check if any argument is an array type
+        let has_array = arg_types.iter().any(|dt| {
+            matches!(dt, List(_) | LargeList(_) | FixedSizeList(_, _))
+        });

Review Comment:
   Is this mirroring what's done in `coerce_types`? We shouldn't duplicate that logic, because the argument types passed to `return_type` have already been coerced.



##########
datafusion/functions/src/string/concat.rs:
##########
@@ -90,23 +191,124 @@ impl ScalarUDFImpl for ConcatFunc {
 
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
         use DataType::*;
-        let mut dt = &Utf8;
-        arg_types.iter().for_each(|data_type| {
-            if data_type == &Utf8View {
-                dt = data_type;
+        
+        // Check if any argument is an array type
+        let has_array = arg_types.iter().any(|dt| {
+            matches!(dt, List(_) | LargeList(_) | FixedSizeList(_, _))
+        });
+        
+        if has_array {
+            // Use array_concat-style return type logic
+            use datafusion_common::utils::{base_type, list_ndims};
+            use datafusion_expr::binary::type_union_resolution;
+            
+            let mut max_dims = 0;
+            let mut large_list = false;
+            let mut element_types = Vec::with_capacity(arg_types.len());
+            
+            for arg_type in arg_types {
+                match arg_type {
+                    Null | List(_) | FixedSizeList(..) => (),
+                    LargeList(_) => large_list = true,
+                    _ => {
+                        return plan_err!("concat does not support type 
{arg_type}")
+                    }
+                }
+
+                max_dims = max_dims.max(list_ndims(arg_type));
+                element_types.push(base_type(arg_type))
             }
-            if data_type == &LargeUtf8 && dt != &Utf8View {
-                dt = data_type;
+
+            if max_dims == 0 {
+                Ok(Null)
+            } else {
+                // Handle the case where we have empty arrays (Null element 
types)
+                // Filter out Null types and find the first non-null type
+                let non_null_types: Vec<_> = element_types.iter().filter(|t| 
**t != Null).collect();
+                
+                let unified_element_type = if non_null_types.is_empty() {
+                    // All arrays are empty, use Int32 as default
+                    Int32
+                } else if non_null_types.len() == 1 {
+                    // Only one non-null type
+                    non_null_types[0].clone()
+                } else {
+                    // Multiple non-null types, try to unify them
+                    if let Some(unified) = 
type_union_resolution(&non_null_types.into_iter().cloned().collect::<Vec<_>>()) 
{
+                        unified
+                    } else {
+                        return plan_err!(
+                            "Failed to unify argument types of concat: [{}]",
+                            arg_types.iter().map(|t| 
format!("{t}")).collect::<Vec<_>>().join(", ")
+                        );
+                    }
+                };
+                
+                // Build the return type
+                let mut return_type = unified_element_type;
+                for _ in 1..max_dims {
+                    return_type = DataType::new_list(return_type, true)
+                }
+
+                if large_list {
+                    Ok(DataType::new_large_list(return_type, true))
+                } else {
+                    Ok(DataType::new_list(return_type, true))
+                }
             }
-        });
+        } else {
+            // Original string concatenation logic
+            let mut dt = &Utf8;
+            arg_types.iter().for_each(|data_type| {
+                if data_type == &Utf8View {
+                    dt = data_type;
+                }
+                if data_type == &LargeUtf8 && dt != &Utf8View {
+                    dt = data_type;
+                }
+            });
 
-        Ok(dt.to_owned())
+            Ok(dt.to_owned())
+        }
     }
 
     /// Concatenates the text representations of all the arguments. NULL 
arguments are ignored.
     /// concat('abcde', 2, NULL, 22) = 'abcde222'
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
         let ScalarFunctionArgs { args, .. } = args;
+        
+        // Check if any argument is an array type
+        let has_array = args.iter().any(|arg| {
+            matches!(
+                arg.data_type(),
+                DataType::List(_) | DataType::LargeList(_) | 
DataType::FixedSizeList(_, _)
+            )
+        });
+        
+        if has_array {
+            // Determine the number of rows for array operations
+            let num_rows = args
+                .iter()
+                .filter_map(|arg| match arg {
+                    ColumnarValue::Array(array) => Some(array.len()),
+                    _ => None,
+                })
+                .next()
+                .unwrap_or(1);
+            
+            // Convert to ArrayRef and delegate to array_concat_inner

Review Comment:
   Can we remove these LLM-like comments? They don't provide much benefit and just add verbosity.
   
   In this case the comment is actually wrong, because the code below isn't delegating to `array_concat_inner` 🤔 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to