alamb commented on code in PR #12093:
URL: https://github.com/apache/datafusion/pull/12093#discussion_r1726774561


##########
datafusion/functions/src/string/split_part.rs:
##########
@@ -82,36 +84,121 @@ impl ScalarUDFImpl for SplitPartFunc {
     }
 
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
-        match (args[0].data_type(), args[1].data_type()) {
-            (
-                DataType::Utf8 | DataType::Utf8View,
-                DataType::Utf8 | DataType::Utf8View,
-            ) => make_scalar_function(split_part::<i32, i32>, vec![])(args),
+        // First, determine if any of the arguments is an Array
+        let len = args.iter().find_map(|arg| match arg {
+            ColumnarValue::Array(a) => Some(a.len()),
+            _ => None,
+        });
+
+        let inferred_length = len.unwrap_or(1);
+        let is_scalar = len.is_none();
+
+        // Convert all ColumnarValues to ArrayRefs
+        let args = args
+            .iter()
+            .map(|arg| match arg {
+                ColumnarValue::Scalar(scalar) => 
scalar.to_array_of_size(inferred_length),
+                ColumnarValue::Array(array) => Ok(Arc::clone(array)),
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        // Unpack the ArrayRefs from the arguments
+        let n_array = as_int64_array(&args[2])?;
+        let result = match (args[0].data_type(), args[1].data_type()) {
+            (DataType::Utf8View, DataType::Utf8View) => {
+                split_part_impl::<&StringViewArray, &StringViewArray, i32>(
+                    args[0].as_string_view(),
+                    args[1].as_string_view(),
+                    n_array,
+                )
+            }
+            (DataType::Utf8View, DataType::Utf8) => {
+                split_part_impl::<&StringViewArray, &GenericStringArray<i32>, 
i32>(
+                    args[0].as_string_view(),
+                    args[1].as_string::<i32>(),
+                    n_array,
+                )
+            }
+            (DataType::Utf8View, DataType::LargeUtf8) => {
+                split_part_impl::<&StringViewArray, &GenericStringArray<i64>, 
i32>(
+                    args[0].as_string_view(),
+                    args[1].as_string::<i64>(),
+                    n_array,
+                )
+            }
+            (DataType::Utf8, DataType::Utf8View) => {
+                split_part_impl::<&GenericStringArray<i32>, &StringViewArray, 
i32>(
+                    args[0].as_string::<i32>(),
+                    args[1].as_string_view(),
+                    n_array,
+                )
+            }
+            (DataType::LargeUtf8, DataType::Utf8View) => {
+                split_part_impl::<&GenericStringArray<i64>, &StringViewArray, 
i64>(
+                    args[0].as_string::<i64>(),
+                    args[1].as_string_view(),
+                    n_array,
+                )
+            }
+            (DataType::Utf8, DataType::Utf8) => {
+                split_part_impl::<&GenericStringArray<i32>, 
&GenericStringArray<i32>, i32>(
+                    args[0].as_string::<i32>(),
+                    args[1].as_string::<i32>(),
+                    n_array,
+                )
+            }
             (DataType::LargeUtf8, DataType::LargeUtf8) => {
-                make_scalar_function(split_part::<i64, i64>, vec![])(args)
+                split_part_impl::<&GenericStringArray<i64>, 
&GenericStringArray<i64>, i64>(
+                    args[0].as_string::<i64>(),
+                    args[1].as_string::<i64>(),
+                    n_array,
+                )
             }
-            (_, DataType::LargeUtf8) => {
-                make_scalar_function(split_part::<i32, i64>, vec![])(args)
+            (DataType::Utf8, DataType::LargeUtf8) => {
+                split_part_impl::<&GenericStringArray<i32>, 
&GenericStringArray<i64>, i32>(
+                    args[0].as_string::<i32>(),
+                    args[1].as_string::<i64>(),
+                    n_array,
+                )
             }
-            (DataType::LargeUtf8, _) => {
-                make_scalar_function(split_part::<i64, i32>, vec![])(args)
+            (DataType::LargeUtf8, DataType::Utf8) => {
+                split_part_impl::<&GenericStringArray<i64>, 
&GenericStringArray<i32>, i64>(
+                    args[0].as_string::<i64>(),
+                    args[1].as_string::<i32>(),
+                    n_array,
+                )
             }
-            (first_type, second_type) => exec_err!(
-                "unsupported first type {} and second type {} for split_part 
function",
-                first_type,
-                second_type
-            ),
+            _ => exec_err!("Unsupported combination of argument types for 
split_part"),
+        };
+        if is_scalar {
+            // If all inputs are scalar, keep the output as scalar
+            let result = result.and_then(|arr| 
ScalarValue::try_from_array(&arr, 0));
+            result.map(ColumnarValue::Scalar)
+        } else {
+            result.map(ColumnarValue::Array)
         }
     }
 }
 
-macro_rules! process_split_part {
-    ($string_array: expr, $delimiter_array: expr, $n_array: expr) => {{
-        let result = $string_array
-            .iter()
-            .zip($delimiter_array.iter())
-            .zip($n_array.iter())
-            .map(|((string, delimiter), n)| match (string, delimiter, n) {
+/// impl
+pub fn split_part_impl<'a, StringArrType, DelimiterArrType, StringArrayLen>(

Review Comment:
   I think this is a nice pattern. Thanks, @Lordworms!



##########
datafusion/functions/src/string/split_part.rs:
##########
@@ -82,36 +84,121 @@ impl ScalarUDFImpl for SplitPartFunc {
     }
 
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
-        match (args[0].data_type(), args[1].data_type()) {
-            (
-                DataType::Utf8 | DataType::Utf8View,
-                DataType::Utf8 | DataType::Utf8View,
-            ) => make_scalar_function(split_part::<i32, i32>, vec![])(args),
+        // First, determine if any of the arguments is an Array
+        let len = args.iter().find_map(|arg| match arg {
+            ColumnarValue::Array(a) => Some(a.len()),
+            _ => None,
+        });
+
+        let inferred_length = len.unwrap_or(1);
+        let is_scalar = len.is_none();
+
+        // Convert all ColumnarValues to ArrayRefs
+        let args = args
+            .iter()
+            .map(|arg| match arg {
+                ColumnarValue::Scalar(scalar) => 
scalar.to_array_of_size(inferred_length),
+                ColumnarValue::Array(array) => Ok(Arc::clone(array)),
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        // Unpack the ArrayRefs from the arguments
+        let n_array = as_int64_array(&args[2])?;
+        let result = match (args[0].data_type(), args[1].data_type()) {

Review Comment:
   Nice.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to