alamb commented on code in PR #12093: URL: https://github.com/apache/datafusion/pull/12093#discussion_r1726774561
########## datafusion/functions/src/string/split_part.rs: ########## @@ -82,36 +84,121 @@ impl ScalarUDFImpl for SplitPartFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { - match (args[0].data_type(), args[1].data_type()) { - ( - DataType::Utf8 | DataType::Utf8View, - DataType::Utf8 | DataType::Utf8View, - ) => make_scalar_function(split_part::<i32, i32>, vec![])(args), + // First, determine if any of the arguments is an Array + let len = args.iter().find_map(|arg| match arg { + ColumnarValue::Array(a) => Some(a.len()), + _ => None, + }); + + let inferred_length = len.unwrap_or(1); + let is_scalar = len.is_none(); + + // Convert all ColumnarValues to ArrayRefs + let args = args + .iter() + .map(|arg| match arg { + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(inferred_length), + ColumnarValue::Array(array) => Ok(Arc::clone(array)), + }) + .collect::<Result<Vec<_>>>()?; + + // Unpack the ArrayRefs from the arguments + let n_array = as_int64_array(&args[2])?; + let result = match (args[0].data_type(), args[1].data_type()) { + (DataType::Utf8View, DataType::Utf8View) => { + split_part_impl::<&StringViewArray, &StringViewArray, i32>( + args[0].as_string_view(), + args[1].as_string_view(), + n_array, + ) + } + (DataType::Utf8View, DataType::Utf8) => { + split_part_impl::<&StringViewArray, &GenericStringArray<i32>, i32>( + args[0].as_string_view(), + args[1].as_string::<i32>(), + n_array, + ) + } + (DataType::Utf8View, DataType::LargeUtf8) => { + split_part_impl::<&StringViewArray, &GenericStringArray<i64>, i32>( + args[0].as_string_view(), + args[1].as_string::<i64>(), + n_array, + ) + } + (DataType::Utf8, DataType::Utf8View) => { + split_part_impl::<&GenericStringArray<i32>, &StringViewArray, i32>( + args[0].as_string::<i32>(), + args[1].as_string_view(), + n_array, + ) + } + (DataType::LargeUtf8, DataType::Utf8View) => { + split_part_impl::<&GenericStringArray<i64>, &StringViewArray, i64>( + args[0].as_string::<i64>(), + args[1].as_string_view(), + n_array, + ) + } + (DataType::Utf8, DataType::Utf8) => { + split_part_impl::<&GenericStringArray<i32>, &GenericStringArray<i32>, i32>( + args[0].as_string::<i32>(), + args[1].as_string::<i32>(), + n_array, + ) + } (DataType::LargeUtf8, DataType::LargeUtf8) => { - make_scalar_function(split_part::<i64, i64>, vec![])(args) + split_part_impl::<&GenericStringArray<i64>, &GenericStringArray<i64>, i64>( + args[0].as_string::<i64>(), + args[1].as_string::<i64>(), + n_array, + ) } - (_, DataType::LargeUtf8) => { - make_scalar_function(split_part::<i32, i64>, vec![])(args) + (DataType::Utf8, DataType::LargeUtf8) => { + split_part_impl::<&GenericStringArray<i32>, &GenericStringArray<i64>, i32>( + args[0].as_string::<i32>(), + args[1].as_string::<i64>(), + n_array, + ) } - (DataType::LargeUtf8, _) => { - make_scalar_function(split_part::<i64, i32>, vec![])(args) + (DataType::LargeUtf8, DataType::Utf8) => { + split_part_impl::<&GenericStringArray<i64>, &GenericStringArray<i32>, i64>( + args[0].as_string::<i64>(), + args[1].as_string::<i32>(), + n_array, + ) } - (first_type, second_type) => exec_err!( - "unsupported first type {} and second type {} for split_part function", - first_type, - second_type - ), + _ => exec_err!("Unsupported combination of argument types for split_part"), + }; + if is_scalar { + // If all inputs are scalar, keep the output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) } } } -macro_rules! process_split_part { - ($string_array: expr, $delimiter_array: expr, $n_array: expr) => {{ - let result = $string_array - .iter() - .zip($delimiter_array.iter()) - .zip($n_array.iter()) - .map(|((string, delimiter), n)| match (string, delimiter, n) { +/// impl +pub fn split_part_impl<'a, StringArrType, DelimiterArrType, StringArrayLen>( Review Comment: I think this is a nice pattern. Thanks @Lordworms ########## datafusion/functions/src/string/split_part.rs: ########## @@ -82,36 +84,121 @@ impl ScalarUDFImpl for SplitPartFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { - match (args[0].data_type(), args[1].data_type()) { - ( - DataType::Utf8 | DataType::Utf8View, - DataType::Utf8 | DataType::Utf8View, - ) => make_scalar_function(split_part::<i32, i32>, vec![])(args), + // First, determine if any of the arguments is an Array + let len = args.iter().find_map(|arg| match arg { + ColumnarValue::Array(a) => Some(a.len()), + _ => None, + }); + + let inferred_length = len.unwrap_or(1); + let is_scalar = len.is_none(); + + // Convert all ColumnarValues to ArrayRefs + let args = args + .iter() + .map(|arg| match arg { + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(inferred_length), + ColumnarValue::Array(array) => Ok(Arc::clone(array)), + }) + .collect::<Result<Vec<_>>>()?; + + // Unpack the ArrayRefs from the arguments + let n_array = as_int64_array(&args[2])?; + let result = match (args[0].data_type(), args[1].data_type()) { Review Comment: Nice -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org