neilconway commented on code in PR #20754:
URL: https://github.com/apache/datafusion/pull/20754#discussion_r2961081968


##########
datafusion/functions/src/unicode/strpos.rs:
##########
@@ -127,142 +135,201 @@ impl ScalarUDFImpl for StrposFunc {
 }
 
 fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
-    match (args[0].data_type(), args[1].data_type()) {
-        (DataType::Utf8, DataType::Utf8) => {
-            let string_array = args[0].as_string::<i32>();
-            let substring_array = args[1].as_string::<i32>();
-            calculate_strpos::<_, _, Int32Type>(&string_array, 
&substring_array)
-        }
-        (DataType::Utf8, DataType::Utf8View) => {
-            let string_array = args[0].as_string::<i32>();
-            let substring_array = args[1].as_string_view();
-            calculate_strpos::<_, _, Int32Type>(&string_array, 
&substring_array)
-        }
-        (DataType::Utf8, DataType::LargeUtf8) => {
-            let string_array = args[0].as_string::<i32>();
-            let substring_array = args[1].as_string::<i64>();
-            calculate_strpos::<_, _, Int32Type>(&string_array, 
&substring_array)
-        }
-        (DataType::LargeUtf8, DataType::Utf8) => {
-            let string_array = args[0].as_string::<i64>();
-            let substring_array = args[1].as_string::<i32>();
-            calculate_strpos::<_, _, Int64Type>(&string_array, 
&substring_array)
-        }
-        (DataType::LargeUtf8, DataType::Utf8View) => {
-            let string_array = args[0].as_string::<i64>();
-            let substring_array = args[1].as_string_view();
-            calculate_strpos::<_, _, Int64Type>(&string_array, 
&substring_array)
-        }
-        (DataType::LargeUtf8, DataType::LargeUtf8) => {
-            let string_array = args[0].as_string::<i64>();
-            let substring_array = args[1].as_string::<i64>();
-            calculate_strpos::<_, _, Int64Type>(&string_array, 
&substring_array)
-        }
-        (DataType::Utf8View, DataType::Utf8View) => {
-            let string_array = args[0].as_string_view();
-            let substring_array = args[1].as_string_view();
-            calculate_strpos::<_, _, Int32Type>(&string_array, 
&substring_array)
-        }
-        (DataType::Utf8View, DataType::Utf8) => {
-            let string_array = args[0].as_string_view();
-            let substring_array = args[1].as_string::<i32>();
-            calculate_strpos::<_, _, Int32Type>(&string_array, 
&substring_array)
-        }
-        (DataType::Utf8View, DataType::LargeUtf8) => {
-            let string_array = args[0].as_string_view();
-            let substring_array = args[1].as_string::<i64>();
-            calculate_strpos::<_, _, Int32Type>(&string_array, 
&substring_array)
-        }
+    /// Dispatches the needle array to the correct string type and calls
+    /// `strpos_general` with the given haystack and result type.
+    macro_rules! dispatch_needle {
+        ($haystack:expr, $result_type:ty, $args:expr) => {
+            match $args[1].data_type() {
+                DataType::Utf8 => strpos_general::<_, _, $result_type>(
+                    $haystack,
+                    $args[1].as_string::<i32>(),
+                ),
+                DataType::LargeUtf8 => strpos_general::<_, _, $result_type>(
+                    $haystack,
+                    $args[1].as_string::<i64>(),
+                ),
+                DataType::Utf8View => strpos_general::<_, _, $result_type>(
+                    $haystack,
+                    $args[1].as_string_view(),
+                ),
+                other => exec_err!(
+                    "Unsupported data type {other:?} for function strpos 
needle"
+                ),
+            }
+        };
+    }
 
+    match args[0].data_type() {
+        DataType::Utf8 => dispatch_needle!(args[0].as_string::<i32>(), 
Int32Type, args),
+        DataType::LargeUtf8 => {
+            dispatch_needle!(args[0].as_string::<i64>(), Int64Type, args)
+        }
+        DataType::Utf8View => dispatch_needle!(args[0].as_string_view(), 
Int32Type, args),
         other => {
-            exec_err!("Unsupported data type combination {other:?} for 
function strpos")
+            exec_err!("Unsupported data type {other:?} for function strpos 
haystack")
         }
     }
 }
 
 /// Find `needle` in `haystack` using `memchr` to quickly skip to positions
-/// where the first byte matches, then verify the remaining bytes. Using
-/// string::find is slower because it has significant per-call overhead that
-/// `memchr` does not, and strpos is often invoked many times on short inputs.
-/// Returns a 1-based position, or 0 if not found.
-/// Both inputs must be ASCII-only.
-fn find_ascii_substring(haystack: &[u8], needle: &[u8]) -> usize {
+/// where the first byte matches, then verify the remaining bytes. Returns
+/// the 0-based byte offset of the match, or `None` if not found. An empty
+/// `needle` matches at offset 0.
+fn find_substring_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
     let needle_len = needle.len();
+    let haystack_len = haystack.len();
+
+    if needle_len == 0 {
+        return Some(0);
+    }
+    if needle_len > haystack_len {
+        return None;
+    }
+
     let first_byte = needle[0];
     let mut offset = 0;
 
     while let Some(pos) = memchr(first_byte, &haystack[offset..]) {
         let start = offset + pos;
         if start + needle_len > haystack.len() {
-            return 0;
+            return None;
         }
         if haystack[start..start + needle_len] == *needle {
-            return start + 1;
+            return Some(start);
         }
         offset = start + 1;
     }
 
-    0
+    None
 }
 
-/// Returns starting index of specified substring within string, or zero if 
it's not present. (Same as position(substring in string), but note the reversed 
argument order.)
-/// strpos('high', 'ig') = 2
-/// The implementation uses UTF-8 code points as characters
-fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>(
-    string_array: &V1,
-    substring_array: &V2,
+/// Fallback strpos implementation for when both haystack and needle are 
arrays.
+/// Building a new `memmem::Finder` for every row is too expensive; it is 
faster
+/// to use `memchr::memchr`.
+fn strpos_general<'a, V1, V2, T: ArrowPrimitiveType>(
+    haystack_array: V1,
+    needle_array: V2,
 ) -> Result<ArrayRef>
 where
-    V1: StringArrayType<'a, Item = &'a str>,
-    V2: StringArrayType<'a, Item = &'a str>,
+    V1: StringArrayType<'a, Item = &'a str> + Copy,
+    V2: StringArrayType<'a, Item = &'a str> + Copy,
 {
-    let ascii_only = substring_array.is_ascii() && string_array.is_ascii();
-    let string_iter = string_array.iter();
-    let substring_iter = substring_array.iter();
-
-    let result = string_iter
-        .zip(substring_iter)
-        .map(|(string, substring)| match (string, substring) {
-            (Some(string), Some(substring)) => {
-                if substring.is_empty() {
-                    return T::Native::from_usize(1);
+    let ascii_only = needle_array.is_ascii() && haystack_array.is_ascii();
+    let haystack_iter = haystack_array.iter();
+    let needle_iter = needle_array.iter();
+
+    let result = haystack_iter
+        .zip(needle_iter)
+        .map(|(haystack, needle)| match (haystack, needle) {
+            (Some(haystack), Some(needle)) => {
+                let haystack_bytes = haystack.as_bytes();
+                let needle_bytes = needle.as_bytes();
+
+                match find_substring_bytes(haystack_bytes, needle_bytes) {
+                    None => T::Native::from_usize(0),
+                    Some(byte_offset) => {
+                        if ascii_only {
+                            T::Native::from_usize(byte_offset + 1)
+                        } else {
+                            // SAFETY: haystack_bytes is valid UTF-8
+                            let prefix = unsafe {
+                                std::str::from_utf8_unchecked(
+                                    &haystack_bytes[..byte_offset],
+                                )
+                            };
+                            T::Native::from_usize(prefix.chars().count() + 1)

Review Comment:
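   @timsaucer The code is fine as-is, but the `SAFETY` comment merits some
elaboration, and I guess the defensive `debug_assert!` doesn't hurt, either.
The concern is that when reinterpreting a slice of raw bytes as a UTF-8 string,
the slice must begin and end on valid UTF-8 character boundaries. Here the slice
is `haystack_bytes[..byte_offset]`: it starts at offset 0, which is always a
boundary, so what matters is that `byte_offset` is one too. Because both
`needle` and `haystack` are valid UTF-8 and `byte_offset` is the byte index at
which we found an instance of `needle` in `haystack`, we know that `byte_offset`
is a valid UTF-8 character boundary. This follows from how UTF-8 encoding works:
a non-empty valid UTF-8 string always begins with a leading byte, never a
continuation byte (`0b10xxxxxx`), so a match can only start at a character
boundary in `haystack`; and an empty `needle` matches at offset 0, which is
trivially a boundary.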



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to