This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 483663bc79 feat: optimize `lower` and `upper` functions (#9971)
483663bc79 is described below

commit 483663bc7956382e1c3d69bbb758424c126e4dde
Author: JasonLi <[email protected]>
AuthorDate: Mon Apr 15 18:22:09 2024 +0800

    feat: optimize `lower` and `upper` functions (#9971)
    
    * feat: optimize lower and upper functions
    
    * chore: pass cargo check
    
    * chore: pass cargo clippy
    
    * fix: lower and upper bug
    
    * optimize
    
    * using iter to find the first nonascii
    
    * chore: rename function
    
    * refactor: case_conversion_array function
    
    * refactor: remove !string_array.is_nullable() from case_conversion_array
---
 datafusion/functions/Cargo.toml           |  11 +++
 datafusion/functions/benches/lower.rs     |  91 ++++++++++++++++++++
 datafusion/functions/benches/upper.rs     |  46 ++++++++++
 datafusion/functions/src/string/common.rs | 137 ++++++++++++++++++------------
 datafusion/functions/src/string/lower.rs  |  97 ++++++++++++++++++++-
 datafusion/functions/src/string/upper.rs  |  97 ++++++++++++++++++++-
 6 files changed, 420 insertions(+), 59 deletions(-)

diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml
index b3ba809687..cf15b490b6 100644
--- a/datafusion/functions/Cargo.toml
+++ b/datafusion/functions/Cargo.toml
@@ -83,6 +83,7 @@ unicode-segmentation = { version = "^1.7.1", optional = true }
 uuid = { version = "1.7", features = ["v4"], optional = true }
 
 [dev-dependencies]
+arrow = { workspace = true, features = ["test_utils"] }
 criterion = "0.5"
 rand = { workspace = true }
 rstest = { workspace = true }
@@ -117,3 +118,13 @@ required-features = ["unicode_expressions"]
 harness = false
 name = "ltrim"
 required-features = ["string_expressions"]
+
+[[bench]]
+harness = false
+name = "lower"
+required-features = ["string_expressions"]
+
+[[bench]]
+harness = false
+name = "upper"
+required-features = ["string_expressions"]
diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs
new file mode 100644
index 0000000000..fa963f174e
--- /dev/null
+++ b/datafusion/functions/benches/lower.rs
@@ -0,0 +1,91 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate criterion;
+
+use arrow::array::{ArrayRef, StringArray};
+use arrow::util::bench_util::create_string_array_with_len;
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use datafusion_expr::ColumnarValue;
+use datafusion_functions::string;
+use std::sync::Arc;
+
+/// Create an array of args containing a StringArray, where all the values in the
+/// StringArray are ASCII.
+/// * `size` - the length of the StringArray, and
+/// * `str_len` - the length of the strings within the StringArray.
+fn create_args1(size: usize, str_len: usize) -> Vec<ColumnarValue> {
+    let array = Arc::new(create_string_array_with_len::<i32>(size, 0.2, str_len));
+    vec![ColumnarValue::Array(array)]
+}
+
+/// Create an array of args containing a StringArray, where the first value in the
+/// StringArray is non-ASCII.
+/// * `size` - the length of the StringArray, and
+/// * `str_len` - the length of the strings within the StringArray.
+fn create_args2(size: usize) -> Vec<ColumnarValue> {
+    let mut items = Vec::with_capacity(size);
+    items.push("农历新年".to_string());
+    for i in 1..size {
+        items.push(format!("DATAFUSION {}", i));
+    }
+    let array = Arc::new(StringArray::from(items)) as ArrayRef;
+    vec![ColumnarValue::Array(array)]
+}
+
+/// Create an array of args containing a StringArray, where the middle value of the
+/// StringArray is non-ASCII.
+/// * `size` - the length of the StringArray, and
+/// * `str_len` - the length of the strings within the StringArray.
+fn create_args3(size: usize) -> Vec<ColumnarValue> {
+    let mut items = Vec::with_capacity(size);
+    let half = size / 2;
+    for i in 0..half {
+        items.push(format!("DATAFUSION {}", i));
+    }
+    items.push("Ⱦ".to_string());
+    for i in half + 1..size {
+        items.push(format!("DATAFUSION {}", i));
+    }
+    let array = Arc::new(StringArray::from(items)) as ArrayRef;
+    vec![ColumnarValue::Array(array)]
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let lower = string::lower();
+    for size in [1024, 4096, 8192] {
+        let args = create_args1(size, 32);
+        c.bench_function(&format!("lower_all_values_are_ascii: {}", size), |b| {
+            b.iter(|| black_box(lower.invoke(&args)))
+        });
+
+        let args = create_args2(size);
+        c.bench_function(
+            &format!("lower_the_first_value_is_nonascii: {}", size),
+            |b| b.iter(|| black_box(lower.invoke(&args))),
+        );
+
+        let args = create_args3(size);
+        c.bench_function(
+            &format!("lower_the_middle_value_is_nonascii: {}", size),
+            |b| b.iter(|| black_box(lower.invoke(&args))),
+        );
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs
new file mode 100644
index 0000000000..a3e5fbd7a4
--- /dev/null
+++ b/datafusion/functions/benches/upper.rs
@@ -0,0 +1,46 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate criterion;
+
+use arrow::util::bench_util::create_string_array_with_len;
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use datafusion_expr::ColumnarValue;
+use datafusion_functions::string;
+use std::sync::Arc;
+
+/// Create an array of args containing a StringArray, where all the values in the
+/// StringArray are ASCII.
+/// * `size` - the length of the StringArray, and
+/// * `str_len` - the length of the strings within the StringArray.
+fn create_args(size: usize, str_len: usize) -> Vec<ColumnarValue> {
+    let array = Arc::new(create_string_array_with_len::<i32>(size, 0.2, str_len));
+    vec![ColumnarValue::Array(array)]
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let upper = string::upper();
+    for size in [1024, 4096, 8192] {
+        let args = create_args(size, 32);
+        c.bench_function("upper_all_values_are_ascii", |b| {
+            b.iter(|| black_box(upper.invoke(&args)))
+        });
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs
index 2b554db397..97f9e1d93b 100644
--- a/datafusion/functions/src/string/common.rs
+++ b/datafusion/functions/src/string/common.rs
@@ -19,8 +19,10 @@ use std::fmt::{Display, Formatter};
 use std::sync::Arc;
 
 use arrow::array::{
-    new_null_array, Array, ArrayRef, GenericStringArray, OffsetSizeTrait,
+    new_null_array, Array, ArrayRef, GenericStringArray, GenericStringBuilder,
+    OffsetSizeTrait,
 };
+use arrow::buffer::Buffer;
 use arrow::datatypes::DataType;
 
 use datafusion_common::cast::as_generic_string_array;
@@ -112,80 +114,105 @@ pub(crate) fn general_trim<T: OffsetSizeTrait>(
     }
 }
 
-/// applies a unary expression to `args[0]` that is expected to be downcastable to
-/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset)
-/// # Errors
-/// This function errors when:
-/// * the number of arguments is not 1
-/// * the first argument is not castable to a `GenericStringArray`
-pub(crate) fn unary_string_function<'a, T, O, F, R>(
-    args: &[&'a dyn Array],
-    op: F,
-    name: &str,
-) -> Result<GenericStringArray<O>>
-where
-    R: AsRef<str>,
-    O: OffsetSizeTrait,
-    T: OffsetSizeTrait,
-    F: Fn(&'a str) -> R,
-{
-    if args.len() != 1 {
-        return exec_err!(
-            "{:?} args were supplied but {} takes exactly one argument",
-            args.len(),
-            name
-        );
-    }
-
-    let string_array = as_generic_string_array::<T>(args[0])?;
+pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
+    case_conversion(args, |string| string.to_lowercase(), name)
+}
 
-    // first map is the iterator, second is for the `Option<_>`
-    Ok(string_array.iter().map(|string| string.map(&op)).collect())
+pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
+    case_conversion(args, |string| string.to_uppercase(), name)
 }
 
-pub(crate) fn handle<'a, F, R>(
+fn case_conversion<'a, F>(
     args: &'a [ColumnarValue],
     op: F,
     name: &str,
 ) -> Result<ColumnarValue>
 where
-    R: AsRef<str>,
-    F: Fn(&'a str) -> R,
+    F: Fn(&'a str) -> String,
 {
     match &args[0] {
-        ColumnarValue::Array(a) => match a.data_type() {
-            DataType::Utf8 => {
-                Ok(ColumnarValue::Array(Arc::new(unary_string_function::<
-                    i32,
-                    i32,
-                    _,
-                    _,
-                >(
-                    &[a.as_ref()], op, name
-                )?)))
-            }
-            DataType::LargeUtf8 => {
-                Ok(ColumnarValue::Array(Arc::new(unary_string_function::<
-                    i64,
-                    i64,
-                    _,
-                    _,
-                >(
-                    &[a.as_ref()], op, name
-                )?)))
-            }
+        ColumnarValue::Array(array) => match array.data_type() {
+            DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32, _>(
+                array, op,
+            )?)),
+            DataType::LargeUtf8 => Ok(ColumnarValue::Array(case_conversion_array::<
+                i64,
+                _,
+            >(array, op)?)),
             other => exec_err!("Unsupported data type {other:?} for function {name}"),
         },
         ColumnarValue::Scalar(scalar) => match scalar {
             ScalarValue::Utf8(a) => {
-                let result = a.as_ref().map(|x| (op)(x).as_ref().to_string());
+                let result = a.as_ref().map(|x| op(x));
                 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
             }
             ScalarValue::LargeUtf8(a) => {
-                let result = a.as_ref().map(|x| (op)(x).as_ref().to_string());
+                let result = a.as_ref().map(|x| op(x));
                 Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
             }
             other => exec_err!("Unsupported data type {other:?} for function {name}"),
         },
     }
 }
+
+fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result<ArrayRef>
+where
+    O: OffsetSizeTrait,
+    F: Fn(&'a str) -> String,
+{
+    const PRE_ALLOC_BYTES: usize = 8;
+
+    let string_array = as_generic_string_array::<O>(array)?;
+    let value_data = string_array.value_data();
+
+    // All values are ASCII.
+    if value_data.is_ascii() {
+        return case_conversion_ascii_array::<O, _>(string_array, op);
+    }
+
+    // Values contain non-ASCII.
+    let item_len = string_array.len();
+    let capacity = string_array.value_data().len() + PRE_ALLOC_BYTES;
+    let mut builder = GenericStringBuilder::<O>::with_capacity(item_len, capacity);
+
+    if string_array.null_count() == 0 {
+        let iter =
+            (0..item_len).map(|i| Some(op(unsafe { string_array.value_unchecked(i) })));
+        builder.extend(iter);
+    } else {
+        let iter = string_array.iter().map(|string| string.map(&op));
+        builder.extend(iter);
+    }
+    Ok(Arc::new(builder.finish()))
+}
+
+/// All values of string_array are ASCII, and when converting case, there is no changes in the byte
+/// array length. Therefore, the StringArray can be treated as a complete ASCII string for
+/// case conversion, and we can reuse the offsets buffer and the nulls buffer.
+fn case_conversion_ascii_array<'a, O, F>(
+    string_array: &'a GenericStringArray<O>,
+    op: F,
+) -> Result<ArrayRef>
+where
+    O: OffsetSizeTrait,
+    F: Fn(&'a str) -> String,
+{
+    let value_data = string_array.value_data();
+    // SAFETY: all items stored in value_data satisfy UTF8.
+    // ref: impl ByteArrayNativeType for str {...}
+    let str_values = unsafe { std::str::from_utf8_unchecked(value_data) };
+
+    // conversion
+    let converted_values = op(str_values);
+    assert_eq!(converted_values.len(), str_values.len());
+    let bytes = converted_values.into_bytes();
+
+    // build result
+    let values = Buffer::from_vec(bytes);
+    let offsets = string_array.offsets().clone();
+    let nulls = string_array.nulls().cloned();
+    // SAFETY: offsets and nulls are consistent with the input array.
+    Ok(Arc::new(unsafe {
+        GenericStringArray::<O>::new_unchecked(offsets, values, nulls)
+    }))
+}
diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs
index a1eff70422..b9b3840252 100644
--- a/datafusion/functions/src/string/lower.rs
+++ b/datafusion/functions/src/string/lower.rs
@@ -23,7 +23,7 @@ use datafusion_common::Result;
 use datafusion_expr::ColumnarValue;
 use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
 
-use crate::string::common::handle;
+use crate::string::common::to_lower;
 use crate::utils::utf8_to_str_type;
 
 #[derive(Debug)]
@@ -62,6 +62,99 @@ impl ScalarUDFImpl for LowerFunc {
     }
 
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
-        handle(args, |string| string.to_lowercase(), "lower")
+        to_lower(args, "lower")
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{ArrayRef, StringArray};
+    use std::sync::Arc;
+
+    fn to_lower(input: ArrayRef, expected: ArrayRef) -> Result<()> {
+        let func = LowerFunc::new();
+        let args = vec![ColumnarValue::Array(input)];
+        let result = match func.invoke(&args)? {
+            ColumnarValue::Array(result) => result,
+            _ => unreachable!(),
+        };
+        assert_eq!(&expected, &result);
+        Ok(())
+    }
+
+    #[test]
+    fn lower_maybe_optimization() -> Result<()> {
+        let input = Arc::new(StringArray::from(vec![
+            Some("农历新年"),
+            None,
+            Some("DATAFUSION"),
+            Some("0123456789"),
+            Some(""),
+        ])) as ArrayRef;
+
+        let expected = Arc::new(StringArray::from(vec![
+            Some("农历新年"),
+            None,
+            Some("datafusion"),
+            Some("0123456789"),
+            Some(""),
+        ])) as ArrayRef;
+
+        to_lower(input, expected)
+    }
+
+    #[test]
+    fn lower_full_optimization() -> Result<()> {
+        let input = Arc::new(StringArray::from(vec![
+            Some("ARROW"),
+            None,
+            Some("DATAFUSION"),
+            Some("0123456789"),
+            Some(""),
+        ])) as ArrayRef;
+
+        let expected = Arc::new(StringArray::from(vec![
+            Some("arrow"),
+            None,
+            Some("datafusion"),
+            Some("0123456789"),
+            Some(""),
+        ])) as ArrayRef;
+
+        to_lower(input, expected)
+    }
+
+    #[test]
+    fn lower_partial_optimization() -> Result<()> {
+        let input = Arc::new(StringArray::from(vec![
+            Some("ARROW"),
+            None,
+            Some("DATAFUSION"),
+            Some("@_"),
+            Some("0123456789"),
+            Some(""),
+            Some("\t\n"),
+            Some("ὈΔΥΣΣΕΎΣ"),
+            Some("TSCHÜSS"),
+            Some("Ⱦ"), // ⱦ: length change
+            Some("农历新年"),
+        ])) as ArrayRef;
+
+        let expected = Arc::new(StringArray::from(vec![
+            Some("arrow"),
+            None,
+            Some("datafusion"),
+            Some("@_"),
+            Some("0123456789"),
+            Some(""),
+            Some("\t\n"),
+            Some("ὀδυσσεύς"),
+            Some("tschüss"),
+            Some("ⱦ"),
+            Some("农历新年"),
+        ])) as ArrayRef;
+
+        to_lower(input, expected)
     }
 }
diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs
index c21824d30d..8f03d7dc6b 100644
--- a/datafusion/functions/src/string/upper.rs
+++ b/datafusion/functions/src/string/upper.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::string::common::handle;
+use crate::string::common::to_upper;
 use crate::utils::utf8_to_str_type;
 use arrow::datatypes::DataType;
 use datafusion_common::Result;
@@ -59,6 +59,99 @@ impl ScalarUDFImpl for UpperFunc {
     }
 
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
-        handle(args, |string| string.to_uppercase(), "upper")
+        to_upper(args, "upper")
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{ArrayRef, StringArray};
+    use std::sync::Arc;
+
+    fn to_upper(input: ArrayRef, expected: ArrayRef) -> Result<()> {
+        let func = UpperFunc::new();
+        let args = vec![ColumnarValue::Array(input)];
+        let result = match func.invoke(&args)? {
+            ColumnarValue::Array(result) => result,
+            _ => unreachable!(),
+        };
+        assert_eq!(&expected, &result);
+        Ok(())
+    }
+
+    #[test]
+    fn upper_maybe_optimization() -> Result<()> {
+        let input = Arc::new(StringArray::from(vec![
+            Some("农历新年"),
+            None,
+            Some("datafusion"),
+            Some("0123456789"),
+            Some(""),
+        ])) as ArrayRef;
+
+        let expected = Arc::new(StringArray::from(vec![
+            Some("农历新年"),
+            None,
+            Some("DATAFUSION"),
+            Some("0123456789"),
+            Some(""),
+        ])) as ArrayRef;
+
+        to_upper(input, expected)
+    }
+
+    #[test]
+    fn upper_full_optimization() -> Result<()> {
+        let input = Arc::new(StringArray::from(vec![
+            Some("arrow"),
+            None,
+            Some("datafusion"),
+            Some("0123456789"),
+            Some(""),
+        ])) as ArrayRef;
+
+        let expected = Arc::new(StringArray::from(vec![
+            Some("ARROW"),
+            None,
+            Some("DATAFUSION"),
+            Some("0123456789"),
+            Some(""),
+        ])) as ArrayRef;
+
+        to_upper(input, expected)
+    }
+
+    #[test]
+    fn upper_partial_optimization() -> Result<()> {
+        let input = Arc::new(StringArray::from(vec![
+            Some("arrow"),
+            None,
+            Some("datafusion"),
+            Some("@_"),
+            Some("0123456789"),
+            Some(""),
+            Some("\t\n"),
+            Some("ὀδυσσεύς"),
+            Some("tschüß"),
+            Some("ⱦ"), // Ⱦ: length change
+            Some("农历新年"),
+        ])) as ArrayRef;
+
+        let expected = Arc::new(StringArray::from(vec![
+            Some("ARROW"),
+            None,
+            Some("DATAFUSION"),
+            Some("@_"),
+            Some("0123456789"),
+            Some(""),
+            Some("\t\n"),
+            Some("ὈΔΥΣΣΕΎΣ"),
+            Some("TSCHÜSS"),
+            Some("Ⱦ"),
+            Some("农历新年"),
+        ])) as ArrayRef;
+
+        to_upper(input, expected)
     }
 }

Reply via email to