Jefffrey commented on code in PR #18137:
URL: https://github.com/apache/datafusion/pull/18137#discussion_r2659672222
##########
datafusion/functions/src/string/concat.rs:
##########
@@ -501,4 +645,120 @@ mod tests {
}
Ok(())
}
+
+ #[test]
+ fn test_concat_with_integers() -> Result<()> {
Review Comment:
I don't see how these tests relate to the original goal of adding array
concat support to the existing string concat.
##########
datafusion/functions/src/string/concat.rs:
##########
@@ -88,37 +143,91 @@ impl ScalarUDFImpl for ConcatFunc {
&self.signature
}
+ fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+ use DataType::*;
+
+ if arg_types.is_empty() {
+ return plan_err!("concat requires at least one argument");
+ }
+
+ let has_arrays = arg_types
+ .iter()
+ .any(|dt| matches!(dt, List(_) | LargeList(_) | FixedSizeList(_,
_)));
+ let has_non_arrays = arg_types
+ .iter()
+ .any(|dt| !matches!(dt, List(_) | LargeList(_) | FixedSizeList(_,
_) | Null));
+
+ if has_arrays && has_non_arrays {
+ return plan_err!(
+ "Cannot mix array and non-array arguments in concat function."
+ );
+ }
+
+ if has_arrays {
+ return Ok(arg_types.to_vec());
+ }
+
+ let target_type = self.get_string_type_precedence(arg_types);
+
+ // Only coerce types that need coercion, keep string types as-is
+ let coerced_types = arg_types
+ .iter()
+ .map(|data_type| match data_type {
+ Utf8View | Utf8 | LargeUtf8 => data_type.clone(),
+ _ => target_type.clone(),
+ })
+ .collect();
+ Ok(coerced_types)
+ }
+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;
- let mut dt = &Utf8;
- arg_types.iter().for_each(|data_type| {
- if data_type == &Utf8View {
- dt = data_type;
- }
- if data_type == &LargeUtf8 && dt != &Utf8View {
- dt = data_type;
- }
- });
- Ok(dt.to_owned())
+ if arg_types.is_empty() {
+ return plan_err!("concat requires at least one argument");
+ }
+
+ // After coercion, all arguments have the same type category, so check
only the first
+ if let List(field) | LargeList(field) | FixedSizeList(field, _) =
&arg_types[0] {
+ return Ok(List(Arc::new(arrow::datatypes::Field::new(
+ "item",
+ field.data_type().clone(),
+ true,
+ ))));
+ }
+
+ // For non-array arguments, return string type based on precedence
+ let dt = self.get_string_type_precedence(arg_types);
+ Ok(dt)
}
/// Concatenates the text representations of all the arguments. NULL
arguments are ignored.
/// concat('abcde', 2, NULL, 22) = 'abcde222'
fn invoke_with_args(&self, args: ScalarFunctionArgs) ->
Result<ColumnarValue> {
+ use DataType::*;
let ScalarFunctionArgs { args, .. } = args;
- let mut return_datatype = DataType::Utf8;
- args.iter().for_each(|col| {
- if col.data_type() == DataType::Utf8View {
- return_datatype = col.data_type();
- }
- if col.data_type() == DataType::LargeUtf8
- && return_datatype != DataType::Utf8View
- {
- return_datatype = col.data_type();
- }
- });
+ if args.is_empty() {
+ return plan_err!("concat requires at least one argument");
+ }
+
+ // After coercion, all arguments have the same type category, so check
only the first
+ let is_array = match &args[0] {
+ ColumnarValue::Array(array) => matches!(
+ array.data_type(),
+ List(_) | LargeList(_) | FixedSizeList(_, _)
+ ),
+ ColumnarValue::Scalar(scalar) => matches!(
+ scalar.data_type(),
+ List(_) | LargeList(_) | FixedSizeList(_, _)
+ ),
+ };
+ if is_array {
+ return self.concat_arrays(&args);
+ }
+
+ let data_types: Vec<DataType> = args.iter().map(|col|
col.data_type()).collect();
+ let return_datatype = self.get_string_type_precedence(&data_types);
Review Comment:
We can retrieve the return type from `ScalarFunctionArgs` via its `return_type()` method instead of recomputing it from the argument types:
https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarFunctionArgs.html#method.return_type
##########
datafusion/functions/src/string/concat.rs:
##########
@@ -302,59 +421,85 @@ pub(crate) fn simplify_concat(args: Vec<Expr>) ->
Result<ExprSimplifyResult> {
ConcatFunc::new().return_type(&data_types)
}?;
- for arg in args.clone() {
+ for arg in args.iter() {
match arg {
Expr::Literal(ScalarValue::Utf8(None), _) => {}
- Expr::Literal(ScalarValue::LargeUtf8(None), _) => {
- }
- Expr::Literal(ScalarValue::Utf8View(None), _) => { }
+ Expr::Literal(ScalarValue::LargeUtf8(None), _) => {}
+ Expr::Literal(ScalarValue::Utf8View(None), _) => {}
// filter out `null` args
// All literals have been converted to Utf8 or LargeUtf8 in
type_coercion.
// Concatenate it with the `contiguous_scalar`.
Expr::Literal(ScalarValue::Utf8(Some(v)), _) => {
- contiguous_scalar += &v;
+ contiguous_scalar += v;
}
Expr::Literal(ScalarValue::LargeUtf8(Some(v)), _) => {
- contiguous_scalar += &v;
+ contiguous_scalar += v;
}
Expr::Literal(ScalarValue::Utf8View(Some(v)), _) => {
- contiguous_scalar += &v;
+ contiguous_scalar += v;
}
- Expr::Literal(x, _) => {
- return internal_err!(
- "The scalar {x} should be casted to string type during the
type coercion."
- )
+ Expr::Literal(scalar_val, _) => {
+ // Convert non-string, non-array literals to their string
representation
+ // Skip array literals - they should be handled at runtime
Review Comment:
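I'm fairly sure type coercion happens before simplification, so by the time
`simplify_concat` runs, non-string literals should already have been cast to a
string type. For example, `type_coercion` runs before `simplify_expressions`: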
```sql
> explain verbose select 1;
+------------------------------------------------------------+--------------------------+
| plan_type                                                  | plan                     |
+------------------------------------------------------------+--------------------------+
| initial_logical_plan                                       | Projection: Int64(1)     |
|                                                            |   EmptyRelation: rows=1  |
| logical_plan after resolve_grouping_function               | SAME TEXT AS ABOVE       |
| logical_plan after type_coercion                           | SAME TEXT AS ABOVE       |
| analyzed_logical_plan                                      | SAME TEXT AS ABOVE       |
| logical_plan after optimize_unions                         | SAME TEXT AS ABOVE       |
| logical_plan after simplify_expressions                    | SAME TEXT AS ABOVE       |
| logical_plan after replace_distinct_aggregate              | SAME TEXT AS ABOVE       |
```
##########
datafusion/functions/src/string/concat.rs:
##########
@@ -88,37 +143,91 @@ impl ScalarUDFImpl for ConcatFunc {
&self.signature
}
+ fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+ use DataType::*;
+
+ if arg_types.is_empty() {
+ return plan_err!("concat requires at least one argument");
+ }
+
+ let has_arrays = arg_types
+ .iter()
+ .any(|dt| matches!(dt, List(_) | LargeList(_) | FixedSizeList(_,
_)));
+ let has_non_arrays = arg_types
+ .iter()
+ .any(|dt| !matches!(dt, List(_) | LargeList(_) | FixedSizeList(_,
_) | Null));
+
+ if has_arrays && has_non_arrays {
+ return plan_err!(
+ "Cannot mix array and non-array arguments in concat function."
+ );
+ }
+
+ if has_arrays {
+ return Ok(arg_types.to_vec());
+ }
+
+ let target_type = self.get_string_type_precedence(arg_types);
+
+ // Only coerce types that need coercion, keep string types as-is
+ let coerced_types = arg_types
+ .iter()
+ .map(|data_type| match data_type {
+ Utf8View | Utf8 | LargeUtf8 => data_type.clone(),
+ _ => target_type.clone(),
+ })
+ .collect();
+ Ok(coerced_types)
+ }
+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;
- let mut dt = &Utf8;
- arg_types.iter().for_each(|data_type| {
- if data_type == &Utf8View {
- dt = data_type;
- }
- if data_type == &LargeUtf8 && dt != &Utf8View {
- dt = data_type;
- }
- });
- Ok(dt.to_owned())
+ if arg_types.is_empty() {
+ return plan_err!("concat requires at least one argument");
+ }
+
+ // After coercion, all arguments have the same type category, so check
only the first
+ if let List(field) | LargeList(field) | FixedSizeList(field, _) =
&arg_types[0] {
+ return Ok(List(Arc::new(arrow::datatypes::Field::new(
+ "item",
+ field.data_type().clone(),
+ true,
+ ))));
+ }
+
+ // For non-array arguments, return string type based on precedence
+ let dt = self.get_string_type_precedence(arg_types);
+ Ok(dt)
}
/// Concatenates the text representations of all the arguments. NULL
arguments are ignored.
/// concat('abcde', 2, NULL, 22) = 'abcde222'
fn invoke_with_args(&self, args: ScalarFunctionArgs) ->
Result<ColumnarValue> {
+ use DataType::*;
let ScalarFunctionArgs { args, .. } = args;
- let mut return_datatype = DataType::Utf8;
- args.iter().for_each(|col| {
- if col.data_type() == DataType::Utf8View {
- return_datatype = col.data_type();
- }
- if col.data_type() == DataType::LargeUtf8
- && return_datatype != DataType::Utf8View
- {
- return_datatype = col.data_type();
- }
- });
+ if args.is_empty() {
+ return plan_err!("concat requires at least one argument");
+ }
+
+ // After coercion, all arguments have the same type category, so check
only the first
+ let is_array = match &args[0] {
Review Comment:
`ColumnarValue` has a `data_type()` method, so this match over the `Array` and `Scalar` variants is unnecessary:
https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.ColumnarValue.html#method.data_type
##########
datafusion/functions/src/string/concat.rs:
##########
@@ -65,13 +71,62 @@ impl Default for ConcatFunc {
impl ConcatFunc {
pub fn new() -> Self {
- use DataType::*;
Self {
- signature: Signature::variadic(
- vec![Utf8View, Utf8, LargeUtf8],
- Volatility::Immutable,
- ),
+ signature: Signature::user_defined(Volatility::Immutable),
+ }
+ }
+
+ /// Get the string type with highest precedence: Utf8View > LargeUtf8 >
Utf8
+ ///
+ /// Utf8View is preferred for performance (zero-copy views),
+ /// LargeUtf8 supports larger strings (i64 offsets),
+ /// Utf8 is the fallback standard string type
+ fn get_string_type_precedence(&self, arg_types: &[DataType]) -> DataType {
+ use DataType::*;
+
+ for data_type in arg_types {
+ if data_type == &Utf8View {
+ return Utf8View;
+ }
+ }
+
+ for data_type in arg_types {
+ if data_type == &LargeUtf8 {
+ return LargeUtf8;
+ }
}
+
+ Utf8
+ }
+
+ /// Concatenate array arguments
+ fn concat_arrays(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+ if args.is_empty() {
+ return plan_err!("concat requires at least one argument");
+ }
+
+ // Convert ColumnarValue arguments to ArrayRef
Review Comment:
Can we please remove these LLM-generated comments that add no value?
##########
datafusion/spark/src/function/string/concat.rs:
##########
@@ -119,7 +124,43 @@ fn spark_concat(args: ScalarFunctionArgs) ->
Result<ColumnarValue> {
// If all scalars and any is NULL, return NULL immediately
if matches!(null_mask, NullMaskResolution::ReturnNull) {
- return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+ // First check if we're dealing with array types by delegating to
ConcatFunc
+ let concat_func = ConcatFunc::new();
+ let return_type = concat_func.return_type(
+ &arg_values
+ .iter()
+ .map(|arg| arg.data_type())
+ .collect::<Vec<_>>(),
+ )?;
+
+ // Return appropriate null value based on return type
+ return Ok(ColumnarValue::Scalar(match return_type {
Review Comment:
We can simplify this match by using `ScalarValue::try_new_null` to build a null of the return type directly:
https://docs.rs/datafusion/latest/datafusion/common/enum.ScalarValue.html#method.try_new_null
##########
datafusion/spark/src/function/string/concat.rs:
##########
@@ -119,7 +124,43 @@ fn spark_concat(args: ScalarFunctionArgs) ->
Result<ColumnarValue> {
// If all scalars and any is NULL, return NULL immediately
if matches!(null_mask, NullMaskResolution::ReturnNull) {
- return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+ // First check if we're dealing with array types by delegating to
ConcatFunc
+ let concat_func = ConcatFunc::new();
+ let return_type = concat_func.return_type(
Review Comment:
We should be getting the return type from `ScalarFunctionArgs` rather than recomputing it with `ConcatFunc::return_type`.
##########
datafusion/functions/src/string/concat.rs:
##########
@@ -501,4 +645,120 @@ mod tests {
}
Ok(())
}
+
+ #[test]
+ fn test_concat_with_integers() -> Result<()> {
+ use datafusion_common::config::ConfigOptions;
+
+ let args = vec![
+ ColumnarValue::Scalar(ScalarValue::Utf8(Some("abc".to_string()))),
+ ColumnarValue::Scalar(ScalarValue::Int64(Some(123))),
+ ColumnarValue::Scalar(ScalarValue::Utf8(None)), // NULL
+ ColumnarValue::Scalar(ScalarValue::Int64(Some(456))),
+ ];
+
+ let arg_fields = vec![
+ Field::new("a", Utf8, true),
+ Field::new("b", Int64, true),
+ Field::new("c", Utf8, true),
+ Field::new("d", Int64, true),
+ ]
+ .into_iter()
+ .map(Arc::new)
+ .collect();
+
+ let func_args = ScalarFunctionArgs {
+ args,
+ arg_fields,
+ number_rows: 1,
+ return_field: Field::new("f", Utf8, true).into(),
+ config_options: Arc::new(ConfigOptions::default()),
+ };
+
+ let result = ConcatFunc::new().invoke_with_args(func_args)?;
+
+ // Expected result should be "abc123456"
+ match result {
+ ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => {
+ assert_eq!(s, "abc123456");
+ }
+ _ => panic!("Expected scalar UTF8 result, got {result:?}"),
+ }
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_array_concatenation_comprehensive() -> Result<()> {
Review Comment:
Could we move these tests to SQL logic tests (SLTs), please?
##########
datafusion/functions/src/string/concat.rs:
##########
@@ -139,10 +414,14 @@ impl ScalarUDFImpl for ConcatFunc {
match scalar.try_as_str() {
Some(Some(v)) => result.push_str(v),
Some(None) => {} // null literal
- None => plan_err!(
- "Concat function does not support scalar type {}",
- scalar
- )?,
+ None => {
+ // For non-string types, convert to string
representation
+ if scalar.is_null() {
+ // Skip null values
+ } else {
+ result.push_str(&format!("{scalar}"));
+ }
+ }
Review Comment:
I still don't quite understand why this change was necessary.
##########
datafusion/functions/src/string/concat.rs:
##########
@@ -65,13 +71,64 @@ impl Default for ConcatFunc {
impl ConcatFunc {
pub fn new() -> Self {
- use DataType::*;
Self {
- signature: Signature::variadic(
- vec![Utf8View, Utf8, LargeUtf8],
- Volatility::Immutable,
- ),
+ signature: Signature::user_defined(Volatility::Immutable),
+ }
+ }
+
+ /// Get the string type with highest precedence: Utf8View > LargeUtf8 >
Utf8
+ fn get_string_type_precedence(&self, arg_types: &[DataType]) -> DataType {
+ use DataType::*;
+
+ for data_type in arg_types {
+ if data_type == &Utf8View {
+ return Utf8View;
+ }
+ }
+
+ for data_type in arg_types {
+ if data_type == &LargeUtf8 {
+ return LargeUtf8;
+ }
+ }
+
+ Utf8
+ }
+
+ /// Concatenate array arguments
+ fn concat_arrays(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+ if args.is_empty() {
+ return plan_err!("concat requires at least one argument");
}
+
+ // Convert ColumnarValue arguments to ArrayRef
+ let arrays: Result<Vec<Arc<dyn Array>>> = args
+ .iter()
+ .map(|arg| match arg {
+ ColumnarValue::Array(arr) => Ok(Arc::clone(arr)),
+ ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1),
+ })
+ .collect();
+ let arrays = arrays?;
+
+ // Check if all arrays are null - concat errors in this case
Review Comment:
> Yes, this matches PostgreSQL's behavior.
Based on what? Do you have an example query that shows this? I tested this
against PostgreSQL 18 and cannot replicate it; each of these queries returns a single NULL row rather than an error:
```sql
postgres=# select array_cat(null::integer[], null::integer[]);
array_cat
-----------
(1 row)
postgres=# select array_cat(null, null);
array_cat
-----------
(1 row)
postgres=# select concat(null, null);
concat
--------
(1 row)
postgres=# select concat(null::integer[], null::integer[]);
concat
--------
(1 row)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]