This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 269a473fee Faster strpos() string function for ASCII-only case (#12401)
269a473fee is described below

commit 269a473fee9c1d4fe19efba41bc7fa682cc0e848
Author: Jax Liu <[email protected]>
AuthorDate: Wed Sep 18 00:40:00 2024 +0800

    Faster strpos() string function for ASCII-only case (#12401)
    
    * add strpos benchmark
    
    * add faster path for strpos in ascii-only case
    
    * clippy
    
    * compare substring first
    
    * cargo fmt
---
 datafusion/functions/Cargo.toml            |   5 +
 datafusion/functions/benches/strpos.rs     | 142 +++++++++++++++++++++++++++++
 datafusion/functions/src/unicode/strpos.rs |  65 +++++++++----
 3 files changed, 195 insertions(+), 17 deletions(-)

diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml
index 5b6dceaa42..ff1b926a9b 100644
--- a/datafusion/functions/Cargo.toml
+++ b/datafusion/functions/Cargo.toml
@@ -171,3 +171,8 @@ required-features = ["unicode_expressions"]
 harness = false
 name = "character_length"
 required-features = ["unicode_expressions"]
+
+[[bench]]
+harness = false
+name = "strpos"
+required-features = ["unicode_expressions"]
diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs
new file mode 100644
index 0000000000..c78e698268
--- /dev/null
+++ b/datafusion/functions/benches/strpos.rs
@@ -0,0 +1,142 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate criterion;
+
+use arrow::array::{StringArray, StringViewArray};
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use datafusion_expr::ColumnarValue;
+use rand::distributions::Alphanumeric;
+use rand::prelude::StdRng;
+use rand::{Rng, SeedableRng};
+use std::str::Chars;
+use std::sync::Arc;
+
+/// gen_arr(4096, 128, 0.1, 0.1, true) will generate a StringViewArray with
+/// 4096 rows, each row containing a string with 128 random characters.
+/// around 10% of the rows are null, around 10% of the rows are non-ASCII.
+fn gen_string_array(
+    n_rows: usize,
+    str_len_chars: usize,
+    null_density: f32,
+    utf8_density: f32,
+    is_string_view: bool, // false -> StringArray, true -> StringViewArray
+) -> Vec<ColumnarValue> {
+    let mut rng = StdRng::seed_from_u64(42);
+    let rng_ref = &mut rng;
+
+    let utf8 = "DatafusionДатаФусион数据融合📊🔥"; // includes utf8 encoding with 
1~4 bytes
+    let corpus_char_count = utf8.chars().count();
+
+    let mut output_string_vec: Vec<Option<String>> = 
Vec::with_capacity(n_rows);
+    let mut output_sub_string_vec: Vec<Option<String>> = 
Vec::with_capacity(n_rows);
+    for _ in 0..n_rows {
+        let rand_num = rng_ref.gen::<f32>(); // [0.0, 1.0)
+        if rand_num < null_density {
+            output_sub_string_vec.push(None);
+            output_string_vec.push(None);
+        } else if rand_num < null_density + utf8_density {
+            // Generate random UTF8 string
+            let mut generated_string = String::with_capacity(str_len_chars);
+            for _ in 0..str_len_chars {
+                let idx = rng_ref.gen_range(0..corpus_char_count);
+                let char = utf8.chars().nth(idx).unwrap();
+                generated_string.push(char);
+            }
+            
output_sub_string_vec.push(Some(random_substring(generated_string.chars())));
+            output_string_vec.push(Some(generated_string));
+        } else {
+            // Generate random ASCII-only string
+            let value = rng_ref
+                .sample_iter(&Alphanumeric)
+                .take(str_len_chars)
+                .collect();
+            let value = String::from_utf8(value).unwrap();
+            output_sub_string_vec.push(Some(random_substring(value.chars())));
+            output_string_vec.push(Some(value));
+        }
+    }
+
+    if is_string_view {
+        let string_view_array: StringViewArray = 
output_string_vec.into_iter().collect();
+        let sub_string_view_array: StringViewArray =
+            output_sub_string_vec.into_iter().collect();
+        vec![
+            ColumnarValue::Array(Arc::new(string_view_array)),
+            ColumnarValue::Array(Arc::new(sub_string_view_array)),
+        ]
+    } else {
+        let string_array: StringArray = 
output_string_vec.clone().into_iter().collect();
+        let sub_string_array: StringArray = 
output_sub_string_vec.into_iter().collect();
+        vec![
+            ColumnarValue::Array(Arc::new(string_array)),
+            ColumnarValue::Array(Arc::new(sub_string_array)),
+        ]
+    }
+}
+
+fn random_substring(chars: Chars) -> String {
+    // get the substring of a random length from the input string by byte unit
+    let mut rng = StdRng::seed_from_u64(44);
+    let count = chars.clone().count();
+    let start = rng.gen_range(0..count - 1);
+    let end = rng.gen_range(start + 1..count);
+    chars
+        .enumerate()
+        .filter(|(i, _)| *i >= start && *i < end)
+        .map(|(_, c)| c)
+        .collect()
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    // All benches are single batch run with 8192 rows
+    let strpos = datafusion_functions::unicode::strpos();
+
+    let n_rows = 8192;
+    for str_len in [8, 32, 128, 4096] {
+        // StringArray ASCII only
+        let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, 
false);
+        c.bench_function(
+            &format!("strpos_StringArray_ascii_str_len_{}", str_len),
+            |b| b.iter(|| black_box(strpos.invoke(&args_string_ascii))),
+        );
+
+        // StringArray UTF8
+        let args_string_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, 
false);
+        c.bench_function(
+            &format!("strpos_StringArray_utf8_str_len_{}", str_len),
+            |b| b.iter(|| black_box(strpos.invoke(&args_string_utf8))),
+        );
+
+        // StringViewArray ASCII only
+        let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 
0.0, true);
+        c.bench_function(
+            &format!("strpos_StringViewArray_ascii_str_len_{}", str_len),
+            |b| b.iter(|| black_box(strpos.invoke(&args_string_view_ascii))),
+        );
+
+        // StringViewArray UTF8
+        let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 
0.5, true);
+        c.bench_function(
+            &format!("strpos_StringViewArray_utf8_str_len_{}", str_len),
+            |b| b.iter(|| black_box(strpos.invoke(&args_string_view_utf8))),
+        );
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs
index cf10b18ae3..6da67c8a27 100644
--- a/datafusion/functions/src/unicode/strpos.rs
+++ b/datafusion/functions/src/unicode/strpos.rs
@@ -18,17 +18,15 @@
 use std::any::Any;
 use std::sync::Arc;
 
-use arrow::array::{
-    ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, 
PrimitiveArray,
-};
+use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray};
 use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
 
+use crate::string::common::StringArrayType;
+use crate::utils::{make_scalar_function, utf8_to_int_type};
 use datafusion_common::{exec_err, Result};
 use datafusion_expr::TypeSignature::Exact;
 use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
 
-use crate::utils::{make_scalar_function, utf8_to_int_type};
-
 #[derive(Debug)]
 pub struct StrposFunc {
     signature: Signature,
@@ -140,24 +138,43 @@ fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>(
     substring_array: V2,
 ) -> Result<ArrayRef>
 where
-    V1: ArrayAccessor<Item = &'a str>,
-    V2: ArrayAccessor<Item = &'a str>,
+    V1: StringArrayType<'a, Item = &'a str>,
+    V2: StringArrayType<'a, Item = &'a str>,
 {
-    let string_iter = ArrayIter::new(string_array);
-    let substring_iter = ArrayIter::new(substring_array);
+    let ascii_only = substring_array.is_ascii() && string_array.is_ascii();
+    let string_iter = string_array.iter();
+    let substring_iter = substring_array.iter();
 
     let result = string_iter
         .zip(substring_iter)
         .map(|(string, substring)| match (string, substring) {
             (Some(string), Some(substring)) => {
-                // The `find` method returns the byte index of the substring.
-                // We count the number of chars up to that byte index.
-                T::Native::from_usize(
-                    string
-                        .find(substring)
-                        .map(|x| string[..x].chars().count() + 1)
-                        .unwrap_or(0),
-                )
+                // If only ASCII characters are present, we can use the slide 
window method to find
+                // the sub vector in the main vector. This is faster than 
string.find() method.
+                if ascii_only {
+                    // If the substring is empty, the result is 1.
+                    if substring.as_bytes().is_empty() {
+                        T::Native::from_usize(1)
+                    } else {
+                        T::Native::from_usize(
+                            string
+                                .as_bytes()
+                                .windows(substring.as_bytes().len())
+                                .position(|w| w == substring.as_bytes())
+                                .map(|x| x + 1)
+                                .unwrap_or(0),
+                        )
+                    }
+                } else {
+                    // The `find` method returns the byte index of the 
substring.
+                    // We count the number of chars up to that byte index.
+                    T::Native::from_usize(
+                        string
+                            .find(substring)
+                            .map(|x| string[..x].chars().count() + 1)
+                            .unwrap_or(0),
+                    )
+                }
             }
             _ => None,
         })
@@ -201,6 +218,8 @@ mod tests {
         test_strpos!("alphabet", "z" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
         test_strpos!("alphabet", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
         test_strpos!("", "a" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
+        test_strpos!("", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
+        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 Utf8 i32 Int32 
Int32Array);
 
         // LargeUtf8 and LargeUtf8 combinations
         test_strpos!("alphabet", "ph" -> 3; LargeUtf8 LargeUtf8 i64 Int64 
Int64Array);
@@ -208,6 +227,8 @@ mod tests {
         test_strpos!("alphabet", "z" -> 0; LargeUtf8 LargeUtf8 i64 Int64 
Int64Array);
         test_strpos!("alphabet", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 
Int64Array);
         test_strpos!("", "a" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
+        test_strpos!("", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
+        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 LargeUtf8 i64 
Int64 Int64Array);
 
         // Utf8 and LargeUtf8 combinations
         test_strpos!("alphabet", "ph" -> 3; Utf8 LargeUtf8 i32 Int32 
Int32Array);
@@ -215,6 +236,8 @@ mod tests {
         test_strpos!("alphabet", "z" -> 0; Utf8 LargeUtf8 i32 Int32 
Int32Array);
         test_strpos!("alphabet", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
         test_strpos!("", "a" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
+        test_strpos!("", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
+        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 LargeUtf8 i32 Int32 
Int32Array);
 
         // LargeUtf8 and Utf8 combinations
         test_strpos!("alphabet", "ph" -> 3; LargeUtf8 Utf8 i64 Int64 
Int64Array);
@@ -222,6 +245,8 @@ mod tests {
         test_strpos!("alphabet", "z" -> 0; LargeUtf8 Utf8 i64 Int64 
Int64Array);
         test_strpos!("alphabet", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
         test_strpos!("", "a" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
+        test_strpos!("", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
+        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 Utf8 i64 Int64 
Int64Array);
 
         // Utf8View and Utf8View combinations
         test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8View i32 Int32 
Int32Array);
@@ -229,6 +254,8 @@ mod tests {
         test_strpos!("alphabet", "z" -> 0; Utf8View Utf8View i32 Int32 
Int32Array);
         test_strpos!("alphabet", "" -> 1; Utf8View Utf8View i32 Int32 
Int32Array);
         test_strpos!("", "a" -> 0; Utf8View Utf8View i32 Int32 Int32Array);
+        test_strpos!("", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
+        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8View i32 
Int32 Int32Array);
 
         // Utf8View and Utf8 combinations
         test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8 i32 Int32 
Int32Array);
@@ -236,6 +263,8 @@ mod tests {
         test_strpos!("alphabet", "z" -> 0; Utf8View Utf8 i32 Int32 Int32Array);
         test_strpos!("alphabet", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
         test_strpos!("", "a" -> 0; Utf8View Utf8 i32 Int32 Int32Array);
+        test_strpos!("", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
+        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8 i32 Int32 
Int32Array);
 
         // Utf8View and LargeUtf8 combinations
         test_strpos!("alphabet", "ph" -> 3; Utf8View LargeUtf8 i32 Int32 
Int32Array);
@@ -243,5 +272,7 @@ mod tests {
         test_strpos!("alphabet", "z" -> 0; Utf8View LargeUtf8 i32 Int32 
Int32Array);
         test_strpos!("alphabet", "" -> 1; Utf8View LargeUtf8 i32 Int32 
Int32Array);
         test_strpos!("", "a" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array);
+        test_strpos!("", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
+        test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View LargeUtf8 i32 
Int32 Int32Array);
     }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to