neilconway commented on code in PR #20754:
URL: https://github.com/apache/datafusion/pull/20754#discussion_r2961107284


##########
datafusion/functions/src/unicode/strpos.rs:
##########
@@ -127,142 +135,201 @@ impl ScalarUDFImpl for StrposFunc {
 }
 
 fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
-    match (args[0].data_type(), args[1].data_type()) {
-        (DataType::Utf8, DataType::Utf8) => {
-            let string_array = args[0].as_string::<i32>();
-            let substring_array = args[1].as_string::<i32>();
-            calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
-        }
-        (DataType::Utf8, DataType::Utf8View) => {
-            let string_array = args[0].as_string::<i32>();
-            let substring_array = args[1].as_string_view();
-            calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
-        }
-        (DataType::Utf8, DataType::LargeUtf8) => {
-            let string_array = args[0].as_string::<i32>();
-            let substring_array = args[1].as_string::<i64>();
-            calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
-        }
-        (DataType::LargeUtf8, DataType::Utf8) => {
-            let string_array = args[0].as_string::<i64>();
-            let substring_array = args[1].as_string::<i32>();
-            calculate_strpos::<_, _, Int64Type>(&string_array, &substring_array)
-        }
-        (DataType::LargeUtf8, DataType::Utf8View) => {
-            let string_array = args[0].as_string::<i64>();
-            let substring_array = args[1].as_string_view();
-            calculate_strpos::<_, _, Int64Type>(&string_array, &substring_array)
-        }
-        (DataType::LargeUtf8, DataType::LargeUtf8) => {
-            let string_array = args[0].as_string::<i64>();
-            let substring_array = args[1].as_string::<i64>();
-            calculate_strpos::<_, _, Int64Type>(&string_array, &substring_array)
-        }
-        (DataType::Utf8View, DataType::Utf8View) => {
-            let string_array = args[0].as_string_view();
-            let substring_array = args[1].as_string_view();
-            calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
-        }
-        (DataType::Utf8View, DataType::Utf8) => {
-            let string_array = args[0].as_string_view();
-            let substring_array = args[1].as_string::<i32>();
-            calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
-        }
-        (DataType::Utf8View, DataType::LargeUtf8) => {
-            let string_array = args[0].as_string_view();
-            let substring_array = args[1].as_string::<i64>();
-            calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array)
-        }
+    /// Dispatches the needle array to the correct string type and calls
+    /// `strpos_general` with the given haystack and result type.
+    macro_rules! dispatch_needle {
+        ($haystack:expr, $result_type:ty, $args:expr) => {
+            match $args[1].data_type() {
+                DataType::Utf8 => strpos_general::<_, _, $result_type>(
+                    $haystack,
+                    $args[1].as_string::<i32>(),
+                ),
+                DataType::LargeUtf8 => strpos_general::<_, _, $result_type>(
+                    $haystack,
+                    $args[1].as_string::<i64>(),
+                ),
+                DataType::Utf8View => strpos_general::<_, _, $result_type>(
+                    $haystack,
+                    $args[1].as_string_view(),
+                ),
+                other => exec_err!(
+                    "Unsupported data type {other:?} for function strpos needle"
+                ),
+            }
+        };
+    }
 
+    match args[0].data_type() {
+        DataType::Utf8 => dispatch_needle!(args[0].as_string::<i32>(), Int32Type, args),
+        DataType::LargeUtf8 => {
+            dispatch_needle!(args[0].as_string::<i64>(), Int64Type, args)
+        }
+        DataType::Utf8View => dispatch_needle!(args[0].as_string_view(), Int32Type, args),
         other => {
-            exec_err!("Unsupported data type combination {other:?} for function strpos")
+            exec_err!("Unsupported data type {other:?} for function strpos haystack")
         }
     }
 }
 
 /// Find `needle` in `haystack` using `memchr` to quickly skip to positions
-/// where the first byte matches, then verify the remaining bytes. Using
-/// string::find is slower because it has significant per-call overhead that
-/// `memchr` does not, and strpos is often invoked many times on short inputs.
-/// Returns a 1-based position, or 0 if not found.
-/// Both inputs must be ASCII-only.
-fn find_ascii_substring(haystack: &[u8], needle: &[u8]) -> usize {
+/// where the first byte matches, then verify the remaining bytes. Returns
+/// the 0-based byte offset of the match, or `None` if not found. An empty
+/// `needle` matches at offset 0.
+fn find_substring_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
     let needle_len = needle.len();
+    let haystack_len = haystack.len();
+
+    if needle_len == 0 {
+        return Some(0);
+    }
+    if needle_len > haystack_len {
+        return None;
+    }
+
     let first_byte = needle[0];
     let mut offset = 0;
 
     while let Some(pos) = memchr(first_byte, &haystack[offset..]) {
         let start = offset + pos;
         if start + needle_len > haystack.len() {
-            return 0;
+            return None;
         }
         if haystack[start..start + needle_len] == *needle {
-            return start + 1;
+            return Some(start);
         }
         offset = start + 1;
     }
 
-    0
+    None
 }
 
-/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)
-/// strpos('high', 'ig') = 2
-/// The implementation uses UTF-8 code points as characters
-fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>(
-    string_array: &V1,
-    substring_array: &V2,
+/// Fallback strpos implementation for when both haystack and needle are arrays.
+/// Building a new `memmem::Finder` for every row is too expensive; it is faster
+/// to use `memchr::memchr`.
+fn strpos_general<'a, V1, V2, T: ArrowPrimitiveType>(
+    haystack_array: V1,
+    needle_array: V2,
 ) -> Result<ArrayRef>
 where
-    V1: StringArrayType<'a, Item = &'a str>,
-    V2: StringArrayType<'a, Item = &'a str>,
+    V1: StringArrayType<'a, Item = &'a str> + Copy,
+    V2: StringArrayType<'a, Item = &'a str> + Copy,
 {
-    let ascii_only = substring_array.is_ascii() && string_array.is_ascii();
-    let string_iter = string_array.iter();
-    let substring_iter = substring_array.iter();
-
-    let result = string_iter
-        .zip(substring_iter)
-        .map(|(string, substring)| match (string, substring) {
-            (Some(string), Some(substring)) => {
-                if substring.is_empty() {
-                    return T::Native::from_usize(1);
+    let ascii_only = needle_array.is_ascii() && haystack_array.is_ascii();
+    let haystack_iter = haystack_array.iter();
+    let needle_iter = needle_array.iter();
+
+    let result = haystack_iter
+        .zip(needle_iter)
+        .map(|(haystack, needle)| match (haystack, needle) {
+            (Some(haystack), Some(needle)) => {
+                let haystack_bytes = haystack.as_bytes();
+                let needle_bytes = needle.as_bytes();
+
+                match find_substring_bytes(haystack_bytes, needle_bytes) {
+                    None => T::Native::from_usize(0),
+                    Some(byte_offset) => {
+                        if ascii_only {
+                            T::Native::from_usize(byte_offset + 1)
+                        } else {
+                            // SAFETY: haystack_bytes is valid UTF-8
+                            let prefix = unsafe {
+                                std::str::from_utf8_unchecked(
+                                    &haystack_bytes[..byte_offset],
+                                )
+                            };
+                            T::Native::from_usize(prefix.chars().count() + 1)
+                        }
+                    }
                 }
+            }
+            _ => None,
+        })
+        .collect::<PrimitiveArray<T>>();
 
-                let substring_bytes = substring.as_bytes();
-                let string_bytes = string.as_bytes();
+    Ok(Arc::new(result) as ArrayRef)
+}
 
-                if substring_bytes.len() > string_bytes.len() {
-                    return T::Native::from_usize(0);
-                }
+/// Fast-path strpos implementation for when the haystack is an array and the
+/// needle is a scalar.  We can pre-build a `memmem::Finder` once and reuse it
+/// for every haystack row.
+fn strpos_scalar_needle(
+    haystack_array: &ArrayRef,
+    needle_scalar: &ScalarValue,
+) -> Result<ColumnarValue> {
+    let Some(needle_str) = needle_scalar.try_as_str() else {
+        return exec_err!(
+            "Unsupported data type {needle_scalar:?} for function strpos needle"
+        );
+    };
+
+    // Null needle => null result for every row
+    let Some(needle_str) = needle_str else {
+        return match haystack_array.data_type() {
+            DataType::LargeUtf8 => {
+                Ok(ColumnarValue::Array(Arc::new(
+                    PrimitiveArray::<Int64Type>::new_null(haystack_array.len()),
+                )))
+            }
+            _ => Ok(ColumnarValue::Array(Arc::new(
+                PrimitiveArray::<Int32Type>::new_null(haystack_array.len()),
+            ))),
+        };
+    };
+
+    let result = match haystack_array.data_type() {
+        DataType::Utf8 => strpos_with_finder::<_, Int32Type>(
+            haystack_array.as_string::<i32>(),
+            needle_str,
+        ),
+        DataType::LargeUtf8 => strpos_with_finder::<_, Int64Type>(
+            haystack_array.as_string::<i64>(),
+            needle_str,
+        ),
+        DataType::Utf8View => strpos_with_finder::<_, Int32Type>(
+            haystack_array.as_string_view(),
+            needle_str,
+        ),
+        other => {
+            exec_err!("Unsupported data type {other:?} for function strpos")
+        }
+    }?;
+    Ok(ColumnarValue::Array(result))
+}
 
-                if ascii_only {
-                    T::Native::from_usize(find_ascii_substring(
-                        string_bytes,
-                        substring_bytes,
-                    ))
-                } else {
-                    // For non-ASCII, use a single-pass search that tracks both
-                    // byte position and character position simultaneously
-                    let mut char_pos = 0;
-                    for (byte_idx, _) in string.char_indices() {
-                        char_pos += 1;
-                        if byte_idx + substring_bytes.len() <= string_bytes.len() {
-                            // SAFETY: We just checked that byte_idx + substring_bytes.len() <= string_bytes.len()
-                            let slice = unsafe {
-                                string_bytes.get_unchecked(
-                                    byte_idx..byte_idx + substring_bytes.len(),
+fn strpos_with_finder<'a, V, T: ArrowPrimitiveType>(
+    haystack_array: V,
+    needle: &str,
+) -> Result<ArrayRef>
+where
+    V: StringArrayType<'a, Item = &'a str> + Copy,
+{
+    let needle_bytes = needle.as_bytes();
+    let ascii_haystack = haystack_array.is_ascii();
+    let finder = memmem::Finder::new(needle_bytes);
+
+    let result = haystack_array
+        .iter()
+        .map(|string| match string {
+            Some(string) => {
+                let haystack_bytes = string.as_bytes();
+                match finder.find(haystack_bytes) {
+                    None => T::Native::from_usize(0),
+                    Some(byte_offset) => {
+                        if ascii_haystack {

Review Comment:
   This is a fair point, although this is how the code worked before -- we 
check `is_ascii` once for the entire haystack array. It might be a net win to 
change it, but it would probably benefit from some analysis/benchmarking.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to