This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 8c9e567822 Optimize performance of substr_index and add tests (#9973)
8c9e567822 is described below

commit 8c9e5678228557aff370b137e9029462230df68a
Author: Kevin Mingtarja <[email protected]>
AuthorDate: Tue Apr 9 00:01:26 2024 +0800

    Optimize performance of substr_index and add tests (#9973)
    
    * Optimize performance of substr_index
---
 datafusion/functions/src/unicode/substrindex.rs  | 153 ++++++++++++++++++++---
 datafusion/sqllogictest/test_files/functions.slt |  11 +-
 2 files changed, 143 insertions(+), 21 deletions(-)

diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs
index d00108a68f..da4ff55828 100644
--- a/datafusion/functions/src/unicode/substrindex.rs
+++ b/datafusion/functions/src/unicode/substrindex.rs
@@ -18,7 +18,7 @@
 use std::any::Any;
 use std::sync::Arc;
 
-use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
+use arrow::array::{ArrayRef, OffsetSizeTrait, StringBuilder};
 use arrow::datatypes::DataType;
 
 use datafusion_common::cast::{as_generic_string_array, as_int64_array};
@@ -101,38 +101,151 @@ pub fn substr_index<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
     let delimiter_array = as_generic_string_array::<T>(&args[1])?;
     let count_array = as_int64_array(&args[2])?;
 
-    let result = string_array
+    let mut builder = StringBuilder::new();
+    string_array
         .iter()
         .zip(delimiter_array.iter())
         .zip(count_array.iter())
-        .map(|((string, delimiter), n)| match (string, delimiter, n) {
+        .for_each(|((string, delimiter), n)| match (string, delimiter, n) {
             (Some(string), Some(delimiter), Some(n)) => {
                 // In MySQL, these cases will return an empty string.
                 if n == 0 || string.is_empty() || delimiter.is_empty() {
-                    return Some(String::new());
+                    builder.append_value("");
+                    return;
                 }
 
-                let splitted: Box<dyn Iterator<Item = _>> = if n > 0 {
-                    Box::new(string.split(delimiter))
+                let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
+                let length = if n > 0 {
+                    let splitted = string.split(delimiter);
+                    splitted
+                        .take(occurrences)
+                        .map(|s| s.len() + delimiter.len())
+                        .sum::<usize>()
+                        - delimiter.len()
                 } else {
-                    Box::new(string.rsplit(delimiter))
+                    let splitted = string.rsplit(delimiter);
+                    splitted
+                        .take(occurrences)
+                        .map(|s| s.len() + delimiter.len())
+                        .sum::<usize>()
+                        - delimiter.len()
                 };
-                let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
-                // The length of the substring covered by substr_index.
-                let length = splitted
-                    .take(occurrences) // at least 1 element, since n != 0
-                    .map(|s| s.len() + delimiter.len())
-                    .sum::<usize>()
-                    - delimiter.len();
                 if n > 0 {
-                    Some(string[..length].to_owned())
+                    match string.get(..length) {
+                        Some(substring) => builder.append_value(substring),
+                        None => builder.append_null(),
+                    }
                 } else {
-                    Some(string[string.len() - length..].to_owned())
+                    match string.get(string.len().saturating_sub(length)..) {
+                        Some(substring) => builder.append_value(substring),
+                        None => builder.append_null(),
+                    }
                 }
             }
-            _ => None,
-        })
-        .collect::<GenericStringArray<T>>();
+            _ => builder.append_null(),
+        });
+
+    Ok(Arc::new(builder.finish()) as ArrayRef)
+}
 
-    Ok(Arc::new(result) as ArrayRef)
+#[cfg(test)]
+mod tests {
+    use arrow::array::{Array, StringArray};
+    use arrow::datatypes::DataType::Utf8;
+
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
+
+    use crate::unicode::substrindex::SubstrIndexFunc;
+    use crate::utils::test::test_function;
+
+    #[test]
+    fn test_functions() -> Result<()> {
+        test_function!(
+            SubstrIndexFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
+                ColumnarValue::Scalar(ScalarValue::from(".")),
+                ColumnarValue::Scalar(ScalarValue::from(1i64)),
+            ],
+            Ok(Some("www")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            SubstrIndexFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
+                ColumnarValue::Scalar(ScalarValue::from(".")),
+                ColumnarValue::Scalar(ScalarValue::from(2i64)),
+            ],
+            Ok(Some("www.apache")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            SubstrIndexFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
+                ColumnarValue::Scalar(ScalarValue::from(".")),
+                ColumnarValue::Scalar(ScalarValue::from(-2i64)),
+            ],
+            Ok(Some("apache.org")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            SubstrIndexFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
+                ColumnarValue::Scalar(ScalarValue::from(".")),
+                ColumnarValue::Scalar(ScalarValue::from(-1i64)),
+            ],
+            Ok(Some("org")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            SubstrIndexFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
+                ColumnarValue::Scalar(ScalarValue::from(".")),
+                ColumnarValue::Scalar(ScalarValue::from(0i64)),
+            ],
+            Ok(Some("")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            SubstrIndexFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("")),
+                ColumnarValue::Scalar(ScalarValue::from(".")),
+                ColumnarValue::Scalar(ScalarValue::from(1i64)),
+            ],
+            Ok(Some("")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            SubstrIndexFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
+                ColumnarValue::Scalar(ScalarValue::from("")),
+                ColumnarValue::Scalar(ScalarValue::from(1i64)),
+            ],
+            Ok(Some("")),
+            &str,
+            Utf8,
+            StringArray
+        );
+
+        Ok(())
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt
index 21433ba168..38ebedf565 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -940,7 +940,8 @@ SELECT str, n, substring_index(str, '.', n) AS c FROM
   (VALUES
     ROW('arrow.apache.org'),
     ROW('.'),
-    ROW('...')
+    ROW('...'),
+    ROW(NULL)
   ) AS strings(str),
   (VALUES
     ROW(1),
@@ -954,6 +955,14 @@ SELECT str, n, substring_index(str, '.', n) AS c FROM
   ) AS occurrences(n)
 ORDER BY str DESC, n;
 ----
+NULL -100 NULL
+NULL -3 NULL
+NULL -2 NULL
+NULL -1 NULL
+NULL 1 NULL
+NULL 2 NULL
+NULL 3 NULL
+NULL 100 NULL
 arrow.apache.org -100 arrow.apache.org
 arrow.apache.org -3 arrow.apache.org
 arrow.apache.org -2 apache.org

Reply via email to