martin-g commented on code in PR #20754:
URL: https://github.com/apache/datafusion/pull/20754#discussion_r2956566830
##########
datafusion/functions/src/unicode/strpos.rs:
##########
@@ -127,142 +135,201 @@ impl ScalarUDFImpl for StrposFunc {
}
fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
- match (args[0].data_type(), args[1].data_type()) {
- (DataType::Utf8, DataType::Utf8) => {
- let string_array = args[0].as_string::<i32>();
- let substring_array = args[1].as_string::<i32>();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
- (DataType::Utf8, DataType::Utf8View) => {
- let string_array = args[0].as_string::<i32>();
- let substring_array = args[1].as_string_view();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
- (DataType::Utf8, DataType::LargeUtf8) => {
- let string_array = args[0].as_string::<i32>();
- let substring_array = args[1].as_string::<i64>();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
- (DataType::LargeUtf8, DataType::Utf8) => {
- let string_array = args[0].as_string::<i64>();
- let substring_array = args[1].as_string::<i32>();
- calculate_strpos::<_, _, Int64Type>(&string_array,
&substring_array)
- }
- (DataType::LargeUtf8, DataType::Utf8View) => {
- let string_array = args[0].as_string::<i64>();
- let substring_array = args[1].as_string_view();
- calculate_strpos::<_, _, Int64Type>(&string_array,
&substring_array)
- }
- (DataType::LargeUtf8, DataType::LargeUtf8) => {
- let string_array = args[0].as_string::<i64>();
- let substring_array = args[1].as_string::<i64>();
- calculate_strpos::<_, _, Int64Type>(&string_array,
&substring_array)
- }
- (DataType::Utf8View, DataType::Utf8View) => {
- let string_array = args[0].as_string_view();
- let substring_array = args[1].as_string_view();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
- (DataType::Utf8View, DataType::Utf8) => {
- let string_array = args[0].as_string_view();
- let substring_array = args[1].as_string::<i32>();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
- (DataType::Utf8View, DataType::LargeUtf8) => {
- let string_array = args[0].as_string_view();
- let substring_array = args[1].as_string::<i64>();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
+ /// Dispatches the needle array to the correct string type and calls
+ /// `strpos_general` with the given haystack and result type.
+ macro_rules! dispatch_needle {
+ ($haystack:expr, $result_type:ty, $args:expr) => {
+ match $args[1].data_type() {
+ DataType::Utf8 => strpos_general::<_, _, $result_type>(
+ $haystack,
+ $args[1].as_string::<i32>(),
+ ),
+ DataType::LargeUtf8 => strpos_general::<_, _, $result_type>(
+ $haystack,
+ $args[1].as_string::<i64>(),
+ ),
+ DataType::Utf8View => strpos_general::<_, _, $result_type>(
+ $haystack,
+ $args[1].as_string_view(),
+ ),
+ other => exec_err!(
+ "Unsupported data type {other:?} for function strpos
needle"
+ ),
+ }
+ };
+ }
+ match args[0].data_type() {
+ DataType::Utf8 => dispatch_needle!(args[0].as_string::<i32>(),
Int32Type, args),
+ DataType::LargeUtf8 => {
+ dispatch_needle!(args[0].as_string::<i64>(), Int64Type, args)
+ }
+ DataType::Utf8View => dispatch_needle!(args[0].as_string_view(),
Int32Type, args),
other => {
- exec_err!("Unsupported data type combination {other:?} for
function strpos")
+ exec_err!("Unsupported data type {other:?} for function strpos
haystack")
}
}
}
/// Find `needle` in `haystack` using `memchr` to quickly skip to positions
-/// where the first byte matches, then verify the remaining bytes. Using
-/// string::find is slower because it has significant per-call overhead that
-/// `memchr` does not, and strpos is often invoked many times on short inputs.
-/// Returns a 1-based position, or 0 if not found.
-/// Both inputs must be ASCII-only.
-fn find_ascii_substring(haystack: &[u8], needle: &[u8]) -> usize {
+/// where the first byte matches, then verify the remaining bytes. Returns
+/// the 0-based byte offset of the match, or `None` if not found. An empty
+/// `needle` matches at offset 0.
+fn find_substring_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
let needle_len = needle.len();
+ let haystack_len = haystack.len();
+
+ if needle_len == 0 {
+ return Some(0);
+ }
+ if needle_len > haystack_len {
+ return None;
+ }
+
let first_byte = needle[0];
let mut offset = 0;
while let Some(pos) = memchr(first_byte, &haystack[offset..]) {
let start = offset + pos;
if start + needle_len > haystack.len() {
- return 0;
+ return None;
}
if haystack[start..start + needle_len] == *needle {
- return start + 1;
+ return Some(start);
}
offset = start + 1;
}
- 0
+ None
}
-/// Returns starting index of specified substring within string, or zero if
it's not present. (Same as position(substring in string), but note the reversed
argument order.)
-/// strpos('high', 'ig') = 2
-/// The implementation uses UTF-8 code points as characters
-fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>(
- string_array: &V1,
- substring_array: &V2,
+/// Fallback strpos implementation for when both haystack and needle are
arrays.
+/// Building a new `memmem::Finder` for every row is too expensive; it is
faster
+/// to use `memchr::memchr`.
+fn strpos_general<'a, V1, V2, T: ArrowPrimitiveType>(
+ haystack_array: V1,
+ needle_array: V2,
) -> Result<ArrayRef>
where
- V1: StringArrayType<'a, Item = &'a str>,
- V2: StringArrayType<'a, Item = &'a str>,
+ V1: StringArrayType<'a, Item = &'a str> + Copy,
+ V2: StringArrayType<'a, Item = &'a str> + Copy,
{
- let ascii_only = substring_array.is_ascii() && string_array.is_ascii();
- let string_iter = string_array.iter();
- let substring_iter = substring_array.iter();
-
- let result = string_iter
- .zip(substring_iter)
- .map(|(string, substring)| match (string, substring) {
- (Some(string), Some(substring)) => {
- if substring.is_empty() {
- return T::Native::from_usize(1);
+ let ascii_only = needle_array.is_ascii() && haystack_array.is_ascii();
+ let haystack_iter = haystack_array.iter();
+ let needle_iter = needle_array.iter();
+
+ let result = haystack_iter
+ .zip(needle_iter)
+ .map(|(haystack, needle)| match (haystack, needle) {
+ (Some(haystack), Some(needle)) => {
+ let haystack_bytes = haystack.as_bytes();
+ let needle_bytes = needle.as_bytes();
+
+ match find_substring_bytes(haystack_bytes, needle_bytes) {
+ None => T::Native::from_usize(0),
+ Some(byte_offset) => {
+ if ascii_only {
+ T::Native::from_usize(byte_offset + 1)
+ } else {
+ // SAFETY: haystack_bytes is valid UTF-8
+ let prefix = unsafe {
+ std::str::from_utf8_unchecked(
+ &haystack_bytes[..byte_offset],
+ )
+ };
+ T::Native::from_usize(prefix.chars().count() + 1)
+ }
+ }
}
+ }
+ _ => None,
+ })
+ .collect::<PrimitiveArray<T>>();
- let substring_bytes = substring.as_bytes();
- let string_bytes = string.as_bytes();
+ Ok(Arc::new(result) as ArrayRef)
+}
- if substring_bytes.len() > string_bytes.len() {
- return T::Native::from_usize(0);
- }
+/// Fast-path strpos implementation for when the haystack is an array and the
+/// needle is a scalar. We can pre-build a `memmem::Finder` once and reuse it
+/// for every haystack row.
+fn strpos_scalar_needle(
+ haystack_array: &ArrayRef,
+ needle_scalar: &ScalarValue,
+) -> Result<ColumnarValue> {
+ let Some(needle_str) = needle_scalar.try_as_str() else {
+ return exec_err!(
+ "Unsupported data type {needle_scalar:?} for function strpos
needle"
+ );
+ };
+
+ // Null needle => null result for every row
+ let Some(needle_str) = needle_str else {
+ return match haystack_array.data_type() {
+ DataType::LargeUtf8 => {
+ Ok(ColumnarValue::Array(Arc::new(
+
PrimitiveArray::<Int64Type>::new_null(haystack_array.len()),
+ )))
+ }
+ _ => Ok(ColumnarValue::Array(Arc::new(
+ PrimitiveArray::<Int32Type>::new_null(haystack_array.len()),
+ ))),
+ };
+ };
+
+ let result = match haystack_array.data_type() {
+ DataType::Utf8 => strpos_with_finder::<_, Int32Type>(
+ haystack_array.as_string::<i32>(),
+ needle_str,
+ ),
+ DataType::LargeUtf8 => strpos_with_finder::<_, Int64Type>(
+ haystack_array.as_string::<i64>(),
+ needle_str,
+ ),
+ DataType::Utf8View => strpos_with_finder::<_, Int32Type>(
+ haystack_array.as_string_view(),
+ needle_str,
+ ),
+ other => {
+ exec_err!("Unsupported data type {other:?} for function strpos")
+ }
+ }?;
+ Ok(ColumnarValue::Array(result))
+}
- if ascii_only {
- T::Native::from_usize(find_ascii_substring(
- string_bytes,
- substring_bytes,
- ))
- } else {
- // For non-ASCII, use a single-pass search that tracks both
- // byte position and character position simultaneously
- let mut char_pos = 0;
- for (byte_idx, _) in string.char_indices() {
- char_pos += 1;
- if byte_idx + substring_bytes.len() <=
string_bytes.len() {
- // SAFETY: We just checked that byte_idx +
substring_bytes.len() <= string_bytes.len()
- let slice = unsafe {
- string_bytes.get_unchecked(
- byte_idx..byte_idx + substring_bytes.len(),
+fn strpos_with_finder<'a, V, T: ArrowPrimitiveType>(
+ haystack_array: V,
+ needle: &str,
+) -> Result<ArrayRef>
+where
+ V: StringArrayType<'a, Item = &'a str> + Copy,
+{
+ let needle_bytes = needle.as_bytes();
+ let ascii_haystack = haystack_array.is_ascii();
+ let finder = memmem::Finder::new(needle_bytes);
+
+ let result = haystack_array
+ .iter()
+ .map(|string| match string {
+ Some(string) => {
+ let haystack_bytes = string.as_bytes();
+ match finder.find(haystack_bytes) {
+ None => T::Native::from_usize(0),
+ Some(byte_offset) => {
+ if ascii_haystack {
Review Comment:
Wouldn't it be better to check whether the **current** `string` is ASCII?
Currently, when `ascii_haystack == false`, every row goes through the slow path,
even rows that are themselves ASCII. In other words, a single non-ASCII row
forces all rows through the slow path.
##########
datafusion/functions/src/unicode/strpos.rs:
##########
@@ -127,142 +135,201 @@ impl ScalarUDFImpl for StrposFunc {
}
fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
- match (args[0].data_type(), args[1].data_type()) {
- (DataType::Utf8, DataType::Utf8) => {
- let string_array = args[0].as_string::<i32>();
- let substring_array = args[1].as_string::<i32>();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
- (DataType::Utf8, DataType::Utf8View) => {
- let string_array = args[0].as_string::<i32>();
- let substring_array = args[1].as_string_view();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
- (DataType::Utf8, DataType::LargeUtf8) => {
- let string_array = args[0].as_string::<i32>();
- let substring_array = args[1].as_string::<i64>();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
- (DataType::LargeUtf8, DataType::Utf8) => {
- let string_array = args[0].as_string::<i64>();
- let substring_array = args[1].as_string::<i32>();
- calculate_strpos::<_, _, Int64Type>(&string_array,
&substring_array)
- }
- (DataType::LargeUtf8, DataType::Utf8View) => {
- let string_array = args[0].as_string::<i64>();
- let substring_array = args[1].as_string_view();
- calculate_strpos::<_, _, Int64Type>(&string_array,
&substring_array)
- }
- (DataType::LargeUtf8, DataType::LargeUtf8) => {
- let string_array = args[0].as_string::<i64>();
- let substring_array = args[1].as_string::<i64>();
- calculate_strpos::<_, _, Int64Type>(&string_array,
&substring_array)
- }
- (DataType::Utf8View, DataType::Utf8View) => {
- let string_array = args[0].as_string_view();
- let substring_array = args[1].as_string_view();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
- (DataType::Utf8View, DataType::Utf8) => {
- let string_array = args[0].as_string_view();
- let substring_array = args[1].as_string::<i32>();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
- (DataType::Utf8View, DataType::LargeUtf8) => {
- let string_array = args[0].as_string_view();
- let substring_array = args[1].as_string::<i64>();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
+ /// Dispatches the needle array to the correct string type and calls
+ /// `strpos_general` with the given haystack and result type.
+ macro_rules! dispatch_needle {
+ ($haystack:expr, $result_type:ty, $args:expr) => {
+ match $args[1].data_type() {
+ DataType::Utf8 => strpos_general::<_, _, $result_type>(
+ $haystack,
+ $args[1].as_string::<i32>(),
+ ),
+ DataType::LargeUtf8 => strpos_general::<_, _, $result_type>(
+ $haystack,
+ $args[1].as_string::<i64>(),
+ ),
+ DataType::Utf8View => strpos_general::<_, _, $result_type>(
+ $haystack,
+ $args[1].as_string_view(),
+ ),
+ other => exec_err!(
+ "Unsupported data type {other:?} for function strpos
needle"
+ ),
+ }
+ };
+ }
+ match args[0].data_type() {
+ DataType::Utf8 => dispatch_needle!(args[0].as_string::<i32>(),
Int32Type, args),
+ DataType::LargeUtf8 => {
+ dispatch_needle!(args[0].as_string::<i64>(), Int64Type, args)
+ }
+ DataType::Utf8View => dispatch_needle!(args[0].as_string_view(),
Int32Type, args),
other => {
- exec_err!("Unsupported data type combination {other:?} for
function strpos")
+ exec_err!("Unsupported data type {other:?} for function strpos
haystack")
}
}
}
/// Find `needle` in `haystack` using `memchr` to quickly skip to positions
-/// where the first byte matches, then verify the remaining bytes. Using
-/// string::find is slower because it has significant per-call overhead that
-/// `memchr` does not, and strpos is often invoked many times on short inputs.
-/// Returns a 1-based position, or 0 if not found.
-/// Both inputs must be ASCII-only.
-fn find_ascii_substring(haystack: &[u8], needle: &[u8]) -> usize {
+/// where the first byte matches, then verify the remaining bytes. Returns
+/// the 0-based byte offset of the match, or `None` if not found. An empty
+/// `needle` matches at offset 0.
+fn find_substring_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
let needle_len = needle.len();
+ let haystack_len = haystack.len();
+
+ if needle_len == 0 {
+ return Some(0);
+ }
+ if needle_len > haystack_len {
+ return None;
+ }
+
let first_byte = needle[0];
let mut offset = 0;
while let Some(pos) = memchr(first_byte, &haystack[offset..]) {
let start = offset + pos;
if start + needle_len > haystack.len() {
- return 0;
+ return None;
}
if haystack[start..start + needle_len] == *needle {
- return start + 1;
+ return Some(start);
}
offset = start + 1;
}
- 0
+ None
}
-/// Returns starting index of specified substring within string, or zero if
it's not present. (Same as position(substring in string), but note the reversed
argument order.)
-/// strpos('high', 'ig') = 2
-/// The implementation uses UTF-8 code points as characters
-fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>(
- string_array: &V1,
- substring_array: &V2,
+/// Fallback strpos implementation for when both haystack and needle are
arrays.
+/// Building a new `memmem::Finder` for every row is too expensive; it is
faster
+/// to use `memchr::memchr`.
+fn strpos_general<'a, V1, V2, T: ArrowPrimitiveType>(
+ haystack_array: V1,
+ needle_array: V2,
) -> Result<ArrayRef>
where
- V1: StringArrayType<'a, Item = &'a str>,
- V2: StringArrayType<'a, Item = &'a str>,
+ V1: StringArrayType<'a, Item = &'a str> + Copy,
+ V2: StringArrayType<'a, Item = &'a str> + Copy,
{
- let ascii_only = substring_array.is_ascii() && string_array.is_ascii();
- let string_iter = string_array.iter();
- let substring_iter = substring_array.iter();
-
- let result = string_iter
- .zip(substring_iter)
- .map(|(string, substring)| match (string, substring) {
- (Some(string), Some(substring)) => {
- if substring.is_empty() {
- return T::Native::from_usize(1);
+ let ascii_only = needle_array.is_ascii() && haystack_array.is_ascii();
+ let haystack_iter = haystack_array.iter();
+ let needle_iter = needle_array.iter();
+
+ let result = haystack_iter
+ .zip(needle_iter)
+ .map(|(haystack, needle)| match (haystack, needle) {
+ (Some(haystack), Some(needle)) => {
+ let haystack_bytes = haystack.as_bytes();
+ let needle_bytes = needle.as_bytes();
+
+ match find_substring_bytes(haystack_bytes, needle_bytes) {
+ None => T::Native::from_usize(0),
+ Some(byte_offset) => {
+ if ascii_only {
+ T::Native::from_usize(byte_offset + 1)
+ } else {
+ // SAFETY: haystack_bytes is valid UTF-8
+ let prefix = unsafe {
+ std::str::from_utf8_unchecked(
+ &haystack_bytes[..byte_offset],
+ )
+ };
+ T::Native::from_usize(prefix.chars().count() + 1)
+ }
+ }
}
+ }
+ _ => None,
+ })
+ .collect::<PrimitiveArray<T>>();
- let substring_bytes = substring.as_bytes();
- let string_bytes = string.as_bytes();
+ Ok(Arc::new(result) as ArrayRef)
+}
- if substring_bytes.len() > string_bytes.len() {
- return T::Native::from_usize(0);
- }
+/// Fast-path strpos implementation for when the haystack is an array and the
+/// needle is a scalar. We can pre-build a `memmem::Finder` once and reuse it
+/// for every haystack row.
+fn strpos_scalar_needle(
+ haystack_array: &ArrayRef,
+ needle_scalar: &ScalarValue,
+) -> Result<ColumnarValue> {
+ let Some(needle_str) = needle_scalar.try_as_str() else {
+ return exec_err!(
+ "Unsupported data type {needle_scalar:?} for function strpos
needle"
Review Comment:
```suggestion
"Unsupported data type {needle_scalar:?} for function strpos
needle"
"Unsupported data type {:?} for function strpos needle",
needle_scalar.data_type()
```
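Debug-formatting the `ScalarValue` itself also prints its value (e.g. `Int64(42)`), whereas formatting `needle_scalar.data_type()` prints just the type (`Int64`), which is what an "Unsupported data type" message should report.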
##########
datafusion/sqllogictest/test_files/scalar.slt:
##########
@@ -2212,6 +2212,104 @@ select strpos('joséésoj', arrow_cast(null, 'Utf8'));
----
NULL
+# strpos with array inputs
+statement ok
+CREATE TABLE strpos_table AS VALUES
+ ('alphabet', 'ph'),
+ ('hello world', 'world'),
+ ('hello world', 'xyz'),
+ ('hello world', ''),
+ ('josé', 'é'),
+ ('joséésoj', 'so'),
+ ('ДатаФусион', 'Фусион'),
+ ('数据融合📊🔥', '📊'),
+ ('数据融合📊🔥', '融合'),
+ (NULL, 'abc'),
+ ('hello', NULL);
Review Comment:
```suggestion
('hello', NULL),
('ab', 'abcd');
```
Add a test case where the haystack is shorter than the needle.
##########
datafusion/functions/benches/strpos.rs:
##########
@@ -18,178 +18,201 @@
use arrow::array::{StringArray, StringViewArray};
use arrow::datatypes::{DataType, Field};
use criterion::{Criterion, criterion_group, criterion_main};
+use datafusion_common::ScalarValue;
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use rand::distr::Alphanumeric;
use rand::prelude::StdRng;
use rand::{Rng, SeedableRng};
use std::hint::black_box;
-use std::str::Chars;
use std::sync::Arc;
-/// Returns a `Vec<ColumnarValue>` with two elements: a haystack array and a
-/// needle array. Each haystack is a random string of `str_len_chars`
-/// characters. Each needle is a random contiguous substring of its
-/// corresponding haystack (i.e., the needle is always present in the
haystack).
-/// Around `null_density` fraction of rows are null and `utf8_density` fraction
-/// contain non-ASCII characters; the remaining rows are ASCII-only.
-fn gen_string_array(
- n_rows: usize,
+#[rustfmt::skip]
+const UTF8_CORPUS: &[char] = &[
+ // Cyrillic (2 bytes each)
+ 'А', 'Б', 'В', 'Г', 'Д', 'Е', 'Ж', 'З', 'И', 'К', 'Л', 'М', 'Н', 'О', 'П',
'Р', 'С',
+ 'Т', 'У', 'Ф', 'Х', 'Ц', 'Ч', 'Ш', 'Щ', 'Э', 'Ю', 'Я',
+ // CJK (3 bytes each)
+ '数', '据', '融', '合', '查', '询', '引', '擎', '优', '化', '执', '行', '计', '划',
+ '表', '达',
+ // Emoji (4 bytes each)
+ '📊', '🔥', '🚀', '⚡', '🎯', '💡', '🔧', '📈',
+];
+const N_ROWS: usize = 8192;
+
+/// Returns a random string of `len` characters. If `ascii` is true, the string
+/// is ASCII-only; otherwise it is drawn from `UTF8_CORPUS`.
+fn random_string(rng: &mut StdRng, len: usize, ascii: bool) -> String {
+ if ascii {
+ let value: Vec<u8> =
rng.sample_iter(&Alphanumeric).take(len).collect();
+ String::from_utf8(value).unwrap()
+ } else {
+ (0..len)
+ .map(|_| UTF8_CORPUS[rng.random_range(0..UTF8_CORPUS.len())])
+ .collect()
+ }
+}
+
+/// Wraps `strings` into either a `StringArray` or `StringViewArray`.
+fn to_columnar_value(
+ strings: Vec<Option<String>>,
+ is_string_view: bool,
+) -> ColumnarValue {
+ if is_string_view {
+ let arr: StringViewArray = strings.into_iter().collect();
+ ColumnarValue::Array(Arc::new(arr))
+ } else {
+ let arr: StringArray = strings.into_iter().collect();
+ ColumnarValue::Array(Arc::new(arr))
+ }
+}
+
+/// Returns haystack and needle, where both are arrays. Each needle is a
+/// contiguous substring of its corresponding haystack. Around `null_density`
+/// fraction of rows are null and `utf8_density` fraction contain non-ASCII
+/// characters.
+fn make_array_needle_args(
+ rng: &mut StdRng,
str_len_chars: usize,
null_density: f32,
utf8_density: f32,
- is_string_view: bool, // false -> StringArray, true -> StringViewArray
+ is_string_view: bool,
) -> Vec<ColumnarValue> {
- let mut rng = StdRng::seed_from_u64(42);
- let rng_ref = &mut rng;
-
- let utf8 = "DatafusionДатаФусион数据融合📊🔥"; // includes utf8 encoding with
1~4 bytes
- let corpus_char_count = utf8.chars().count();
-
- let mut output_string_vec: Vec<Option<String>> =
Vec::with_capacity(n_rows);
- let mut output_sub_string_vec: Vec<Option<String>> =
Vec::with_capacity(n_rows);
- for _ in 0..n_rows {
- let rand_num = rng_ref.random::<f32>(); // [0.0, 1.0)
- if rand_num < null_density {
- output_sub_string_vec.push(None);
- output_string_vec.push(None);
- } else if rand_num < null_density + utf8_density {
- // Generate random UTF8 string
- let mut generated_string = String::with_capacity(str_len_chars);
- for _ in 0..str_len_chars {
- let idx = rng_ref.random_range(0..corpus_char_count);
- let char = utf8.chars().nth(idx).unwrap();
- generated_string.push(char);
- }
-
output_sub_string_vec.push(Some(random_substring(generated_string.chars())));
- output_string_vec.push(Some(generated_string));
+ let mut haystacks: Vec<Option<String>> = Vec::with_capacity(N_ROWS);
+ let mut needles: Vec<Option<String>> = Vec::with_capacity(N_ROWS);
+ for _ in 0..N_ROWS {
+ let r = rng.random::<f32>();
+ if r < null_density {
+ haystacks.push(None);
+ needles.push(None);
} else {
- // Generate random ASCII-only string
- let value = rng_ref
+ let ascii = r >= null_density + utf8_density;
+ let s = random_string(rng, str_len_chars, ascii);
+ needles.push(Some(random_substring(rng, &s)));
+ haystacks.push(Some(s));
+ }
+ }
+
+ vec![
+ to_columnar_value(haystacks, is_string_view),
+ to_columnar_value(needles, is_string_view),
+ ]
+}
+
+/// Returns haystack array with a fixed scalar needle inserted into each row.
+/// `utf8_density` fraction of rows contain non-ASCII characters.
+/// The needle must be ASCII.
+fn make_scalar_needle_args(
+ rng: &mut StdRng,
+ str_len_chars: usize,
+ needle: &str,
+ utf8_density: f32,
Review Comment:
What about a `null_density` parameter, like the one `make_array_needle_args` takes for the array case?
##########
datafusion/functions/benches/strpos.rs:
##########
@@ -18,178 +18,201 @@
use arrow::array::{StringArray, StringViewArray};
use arrow::datatypes::{DataType, Field};
use criterion::{Criterion, criterion_group, criterion_main};
+use datafusion_common::ScalarValue;
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use rand::distr::Alphanumeric;
use rand::prelude::StdRng;
use rand::{Rng, SeedableRng};
use std::hint::black_box;
-use std::str::Chars;
use std::sync::Arc;
-/// Returns a `Vec<ColumnarValue>` with two elements: a haystack array and a
-/// needle array. Each haystack is a random string of `str_len_chars`
-/// characters. Each needle is a random contiguous substring of its
-/// corresponding haystack (i.e., the needle is always present in the
haystack).
-/// Around `null_density` fraction of rows are null and `utf8_density` fraction
-/// contain non-ASCII characters; the remaining rows are ASCII-only.
-fn gen_string_array(
- n_rows: usize,
+#[rustfmt::skip]
+const UTF8_CORPUS: &[char] = &[
+ // Cyrillic (2 bytes each)
+ 'А', 'Б', 'В', 'Г', 'Д', 'Е', 'Ж', 'З', 'И', 'К', 'Л', 'М', 'Н', 'О', 'П',
'Р', 'С',
+ 'Т', 'У', 'Ф', 'Х', 'Ц', 'Ч', 'Ш', 'Щ', 'Э', 'Ю', 'Я',
+ // CJK (3 bytes each)
+ '数', '据', '融', '合', '查', '询', '引', '擎', '优', '化', '执', '行', '计', '划',
+ '表', '达',
+ // Emoji (4 bytes each)
+ '📊', '🔥', '🚀', '⚡', '🎯', '💡', '🔧', '📈',
+];
+const N_ROWS: usize = 8192;
+
+/// Returns a random string of `len` characters. If `ascii` is true, the string
+/// is ASCII-only; otherwise it is drawn from `UTF8_CORPUS`.
+fn random_string(rng: &mut StdRng, len: usize, ascii: bool) -> String {
+ if ascii {
+ let value: Vec<u8> =
rng.sample_iter(&Alphanumeric).take(len).collect();
+ String::from_utf8(value).unwrap()
+ } else {
+ (0..len)
+ .map(|_| UTF8_CORPUS[rng.random_range(0..UTF8_CORPUS.len())])
+ .collect()
+ }
+}
+
+/// Wraps `strings` into either a `StringArray` or `StringViewArray`.
+fn to_columnar_value(
+ strings: Vec<Option<String>>,
+ is_string_view: bool,
+) -> ColumnarValue {
+ if is_string_view {
+ let arr: StringViewArray = strings.into_iter().collect();
+ ColumnarValue::Array(Arc::new(arr))
+ } else {
+ let arr: StringArray = strings.into_iter().collect();
+ ColumnarValue::Array(Arc::new(arr))
+ }
+}
+
+/// Returns haystack and needle, where both are arrays. Each needle is a
+/// contiguous substring of its corresponding haystack. Around `null_density`
+/// fraction of rows are null and `utf8_density` fraction contain non-ASCII
+/// characters.
+fn make_array_needle_args(
+ rng: &mut StdRng,
str_len_chars: usize,
null_density: f32,
utf8_density: f32,
- is_string_view: bool, // false -> StringArray, true -> StringViewArray
+ is_string_view: bool,
) -> Vec<ColumnarValue> {
- let mut rng = StdRng::seed_from_u64(42);
- let rng_ref = &mut rng;
-
- let utf8 = "DatafusionДатаФусион数据融合📊🔥"; // includes utf8 encoding with
1~4 bytes
- let corpus_char_count = utf8.chars().count();
-
- let mut output_string_vec: Vec<Option<String>> =
Vec::with_capacity(n_rows);
- let mut output_sub_string_vec: Vec<Option<String>> =
Vec::with_capacity(n_rows);
- for _ in 0..n_rows {
- let rand_num = rng_ref.random::<f32>(); // [0.0, 1.0)
- if rand_num < null_density {
- output_sub_string_vec.push(None);
- output_string_vec.push(None);
- } else if rand_num < null_density + utf8_density {
- // Generate random UTF8 string
- let mut generated_string = String::with_capacity(str_len_chars);
- for _ in 0..str_len_chars {
- let idx = rng_ref.random_range(0..corpus_char_count);
- let char = utf8.chars().nth(idx).unwrap();
- generated_string.push(char);
- }
-
output_sub_string_vec.push(Some(random_substring(generated_string.chars())));
- output_string_vec.push(Some(generated_string));
+ let mut haystacks: Vec<Option<String>> = Vec::with_capacity(N_ROWS);
+ let mut needles: Vec<Option<String>> = Vec::with_capacity(N_ROWS);
+ for _ in 0..N_ROWS {
+ let r = rng.random::<f32>();
+ if r < null_density {
+ haystacks.push(None);
+ needles.push(None);
} else {
- // Generate random ASCII-only string
- let value = rng_ref
+ let ascii = r >= null_density + utf8_density;
+ let s = random_string(rng, str_len_chars, ascii);
+ needles.push(Some(random_substring(rng, &s)));
+ haystacks.push(Some(s));
+ }
+ }
+
+ vec![
+ to_columnar_value(haystacks, is_string_view),
+ to_columnar_value(needles, is_string_view),
+ ]
+}
+
+/// Returns haystack array with a fixed scalar needle inserted into each row.
+/// `utf8_density` fraction of rows contain non-ASCII characters.
+/// The needle must be ASCII.
+fn make_scalar_needle_args(
+ rng: &mut StdRng,
+ str_len_chars: usize,
+ needle: &str,
+ utf8_density: f32,
+ is_string_view: bool,
+) -> Vec<ColumnarValue> {
+ let needle_len = needle.len();
+ assert!(
+ str_len_chars >= needle_len,
+ "str_len_chars must be >= needle length"
+ );
+
+ let mut haystacks: Vec<Option<String>> = Vec::with_capacity(N_ROWS);
+ for _ in 0..N_ROWS {
+ let ascii = rng.random::<f32>() >= utf8_density;
+ if ascii {
+ let mut value: Vec<u8> = (&mut *rng)
.sample_iter(&Alphanumeric)
.take(str_len_chars)
.collect();
- let value = String::from_utf8(value).unwrap();
- output_sub_string_vec.push(Some(random_substring(value.chars())));
- output_string_vec.push(Some(value));
+ let pos = rng.random_range(0..=str_len_chars - needle_len);
+ value[pos..pos + needle_len].copy_from_slice(needle.as_bytes());
+ haystacks.push(Some(String::from_utf8(value).unwrap()));
+ } else {
+ let mut s = random_string(rng, str_len_chars, false);
+ let char_positions: Vec<usize> = s.char_indices().map(|(i, _)|
i).collect();
+ let insert_pos = if char_positions.len() > 1 {
+ char_positions[rng.random_range(0..char_positions.len())]
+ } else {
+ 0
+ };
+ s.insert_str(insert_pos, needle);
+ haystacks.push(Some(s));
}
}
- if is_string_view {
- let string_view_array: StringViewArray =
output_string_vec.into_iter().collect();
- let sub_string_view_array: StringViewArray =
- output_sub_string_vec.into_iter().collect();
- vec![
- ColumnarValue::Array(Arc::new(string_view_array)),
- ColumnarValue::Array(Arc::new(sub_string_view_array)),
- ]
- } else {
- let string_array: StringArray =
output_string_vec.clone().into_iter().collect();
- let sub_string_array: StringArray =
output_sub_string_vec.into_iter().collect();
- vec![
- ColumnarValue::Array(Arc::new(string_array)),
- ColumnarValue::Array(Arc::new(sub_string_array)),
- ]
- }
+ let needle_cv =
ColumnarValue::Scalar(ScalarValue::Utf8(Some(needle.to_string())));
+ vec![to_columnar_value(haystacks, is_string_view), needle_cv]
}
-fn random_substring(chars: Chars) -> String {
- // get the substring of a random length from the input string by byte unit
- let mut rng = StdRng::seed_from_u64(44);
- let count = chars.clone().count();
+/// Extracts a random contiguous substring from `s`.
+fn random_substring(rng: &mut StdRng, s: &str) -> String {
+ let count = s.chars().count();
let start = rng.random_range(0..count - 1);
Review Comment:
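This would panic if `s` is an empty string: `count - 1` underflows when `count == 0` (a panic in debug builds). A single-character `s` would panic too, because `random_range(0..0)` is an empty range.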
##########
datafusion/functions/src/unicode/strpos.rs:
##########
@@ -127,142 +135,201 @@ impl ScalarUDFImpl for StrposFunc {
}
fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
- match (args[0].data_type(), args[1].data_type()) {
- (DataType::Utf8, DataType::Utf8) => {
- let string_array = args[0].as_string::<i32>();
- let substring_array = args[1].as_string::<i32>();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
- (DataType::Utf8, DataType::Utf8View) => {
- let string_array = args[0].as_string::<i32>();
- let substring_array = args[1].as_string_view();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
- (DataType::Utf8, DataType::LargeUtf8) => {
- let string_array = args[0].as_string::<i32>();
- let substring_array = args[1].as_string::<i64>();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
- (DataType::LargeUtf8, DataType::Utf8) => {
- let string_array = args[0].as_string::<i64>();
- let substring_array = args[1].as_string::<i32>();
- calculate_strpos::<_, _, Int64Type>(&string_array,
&substring_array)
- }
- (DataType::LargeUtf8, DataType::Utf8View) => {
- let string_array = args[0].as_string::<i64>();
- let substring_array = args[1].as_string_view();
- calculate_strpos::<_, _, Int64Type>(&string_array,
&substring_array)
- }
- (DataType::LargeUtf8, DataType::LargeUtf8) => {
- let string_array = args[0].as_string::<i64>();
- let substring_array = args[1].as_string::<i64>();
- calculate_strpos::<_, _, Int64Type>(&string_array,
&substring_array)
- }
- (DataType::Utf8View, DataType::Utf8View) => {
- let string_array = args[0].as_string_view();
- let substring_array = args[1].as_string_view();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
- (DataType::Utf8View, DataType::Utf8) => {
- let string_array = args[0].as_string_view();
- let substring_array = args[1].as_string::<i32>();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
- (DataType::Utf8View, DataType::LargeUtf8) => {
- let string_array = args[0].as_string_view();
- let substring_array = args[1].as_string::<i64>();
- calculate_strpos::<_, _, Int32Type>(&string_array,
&substring_array)
- }
+ /// Dispatches the needle array to the correct string type and calls
+ /// `strpos_general` with the given haystack and result type.
+ macro_rules! dispatch_needle {
+ ($haystack:expr, $result_type:ty, $args:expr) => {
+ match $args[1].data_type() {
+ DataType::Utf8 => strpos_general::<_, _, $result_type>(
+ $haystack,
+ $args[1].as_string::<i32>(),
+ ),
+ DataType::LargeUtf8 => strpos_general::<_, _, $result_type>(
+ $haystack,
+ $args[1].as_string::<i64>(),
+ ),
+ DataType::Utf8View => strpos_general::<_, _, $result_type>(
+ $haystack,
+ $args[1].as_string_view(),
+ ),
+ other => exec_err!(
+ "Unsupported data type {other:?} for function strpos
needle"
+ ),
+ }
+ };
+ }
+ match args[0].data_type() {
+ DataType::Utf8 => dispatch_needle!(args[0].as_string::<i32>(),
Int32Type, args),
+ DataType::LargeUtf8 => {
+ dispatch_needle!(args[0].as_string::<i64>(), Int64Type, args)
+ }
+ DataType::Utf8View => dispatch_needle!(args[0].as_string_view(),
Int32Type, args),
other => {
- exec_err!("Unsupported data type combination {other:?} for function strpos")
+ exec_err!("Unsupported data type {other:?} for function strpos haystack")
}
}
}
/// Find `needle` in `haystack` using `memchr` to quickly skip to positions
-/// where the first byte matches, then verify the remaining bytes. Using
-/// string::find is slower because it has significant per-call overhead that
-/// `memchr` does not, and strpos is often invoked many times on short inputs.
-/// Returns a 1-based position, or 0 if not found.
-/// Both inputs must be ASCII-only.
-fn find_ascii_substring(haystack: &[u8], needle: &[u8]) -> usize {
+/// where the first byte matches, then verify the remaining bytes. Returns
+/// the 0-based byte offset of the match, or `None` if not found. An empty
+/// `needle` matches at offset 0.
+fn find_substring_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
let needle_len = needle.len();
+ let haystack_len = haystack.len();
+
+ if needle_len == 0 {
+ return Some(0);
+ }
+ if needle_len > haystack_len {
+ return None;
+ }
+
let first_byte = needle[0];
let mut offset = 0;
while let Some(pos) = memchr(first_byte, &haystack[offset..]) {
let start = offset + pos;
if start + needle_len > haystack.len() {
- return 0;
+ return None;
}
if haystack[start..start + needle_len] == *needle {
- return start + 1;
+ return Some(start);
}
offset = start + 1;
}
- 0
+ None
}
-/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)
-/// strpos('high', 'ig') = 2
-/// The implementation uses UTF-8 code points as characters
-fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>(
- string_array: &V1,
- substring_array: &V2,
+/// Fallback strpos implementation for when both haystack and needle are arrays.
+/// Building a new `memmem::Finder` for every row is too expensive; it is faster
+/// to use `memchr::memchr`.
+fn strpos_general<'a, V1, V2, T: ArrowPrimitiveType>(
+ haystack_array: V1,
+ needle_array: V2,
) -> Result<ArrayRef>
where
- V1: StringArrayType<'a, Item = &'a str>,
- V2: StringArrayType<'a, Item = &'a str>,
+ V1: StringArrayType<'a, Item = &'a str> + Copy,
+ V2: StringArrayType<'a, Item = &'a str> + Copy,
{
- let ascii_only = substring_array.is_ascii() && string_array.is_ascii();
- let string_iter = string_array.iter();
- let substring_iter = substring_array.iter();
-
- let result = string_iter
- .zip(substring_iter)
- .map(|(string, substring)| match (string, substring) {
- (Some(string), Some(substring)) => {
- if substring.is_empty() {
- return T::Native::from_usize(1);
+ let ascii_only = needle_array.is_ascii() && haystack_array.is_ascii();
+ let haystack_iter = haystack_array.iter();
+ let needle_iter = needle_array.iter();
+
+ let result = haystack_iter
+ .zip(needle_iter)
+ .map(|(haystack, needle)| match (haystack, needle) {
+ (Some(haystack), Some(needle)) => {
+ let haystack_bytes = haystack.as_bytes();
+ let needle_bytes = needle.as_bytes();
+
+ match find_substring_bytes(haystack_bytes, needle_bytes) {
+ None => T::Native::from_usize(0),
+ Some(byte_offset) => {
+ if ascii_only {
+ T::Native::from_usize(byte_offset + 1)
+ } else {
+ // SAFETY: haystack_bytes is valid UTF-8
+ let prefix = unsafe {
+ std::str::from_utf8_unchecked(
+ &haystack_bytes[..byte_offset],
+ )
+ };
+ T::Native::from_usize(prefix.chars().count() + 1)
+ }
+ }
}
+ }
+ _ => None,
+ })
+ .collect::<PrimitiveArray<T>>();
- let substring_bytes = substring.as_bytes();
- let string_bytes = string.as_bytes();
+ Ok(Arc::new(result) as ArrayRef)
+}
- if substring_bytes.len() > string_bytes.len() {
- return T::Native::from_usize(0);
- }
+/// Fast-path strpos implementation for when the haystack is an array and the
+/// needle is a scalar. We can pre-build a `memmem::Finder` once and reuse it
+/// for every haystack row.
+fn strpos_scalar_needle(
+ haystack_array: &ArrayRef,
+ needle_scalar: &ScalarValue,
+) -> Result<ColumnarValue> {
+ let Some(needle_str) = needle_scalar.try_as_str() else {
+ return exec_err!(
+ "Unsupported data type {needle_scalar:?} for function strpos needle"
+ );
+ };
+
+ // Null needle => null result for every row
+ let Some(needle_str) = needle_str else {
+ return match haystack_array.data_type() {
+ DataType::LargeUtf8 => {
+ Ok(ColumnarValue::Array(Arc::new(
+ PrimitiveArray::<Int64Type>::new_null(haystack_array.len()),
+ )))
+ }
+ _ => Ok(ColumnarValue::Array(Arc::new(
+ PrimitiveArray::<Int32Type>::new_null(haystack_array.len()),
+ ))),
Review Comment:
```suggestion
DataType::Utf8 | DataType::Utf8View =>
Ok(ColumnarValue::Array(Arc::new(
PrimitiveArray::<Int32Type>::new_null(haystack_array.len()),
))),
other => exec_err!(
"Unsupported data type {other:?} for function strpos haystack"
),
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]