This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 8c9e567822 Optimize performance of substr_index and add tests (#9973)
8c9e567822 is described below
commit 8c9e5678228557aff370b137e9029462230df68a
Author: Kevin Mingtarja <[email protected]>
AuthorDate: Tue Apr 9 00:01:26 2024 +0800
Optimize performance of substr_index and add tests (#9973)
* Optimize performance of substr_index
---
datafusion/functions/src/unicode/substrindex.rs | 153 ++++++++++++++++++++---
datafusion/sqllogictest/test_files/functions.slt | 11 +-
2 files changed, 143 insertions(+), 21 deletions(-)
diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs
index d00108a68f..da4ff55828 100644
--- a/datafusion/functions/src/unicode/substrindex.rs
+++ b/datafusion/functions/src/unicode/substrindex.rs
@@ -18,7 +18,7 @@
use std::any::Any;
use std::sync::Arc;
-use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
+use arrow::array::{ArrayRef, OffsetSizeTrait, StringBuilder};
use arrow::datatypes::DataType;
use datafusion_common::cast::{as_generic_string_array, as_int64_array};
@@ -101,38 +101,151 @@ pub fn substr_index<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let delimiter_array = as_generic_string_array::<T>(&args[1])?;
let count_array = as_int64_array(&args[2])?;
- let result = string_array
+ let mut builder = StringBuilder::new();
+ string_array
.iter()
.zip(delimiter_array.iter())
.zip(count_array.iter())
- .map(|((string, delimiter), n)| match (string, delimiter, n) {
+ .for_each(|((string, delimiter), n)| match (string, delimiter, n) {
(Some(string), Some(delimiter), Some(n)) => {
// In MySQL, these cases will return an empty string.
if n == 0 || string.is_empty() || delimiter.is_empty() {
- return Some(String::new());
+ builder.append_value("");
+ return;
}
- let splitted: Box<dyn Iterator<Item = _>> = if n > 0 {
- Box::new(string.split(delimiter))
+ let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
+ let length = if n > 0 {
+ let splitted = string.split(delimiter);
+ splitted
+ .take(occurrences)
+ .map(|s| s.len() + delimiter.len())
+ .sum::<usize>()
+ - delimiter.len()
} else {
- Box::new(string.rsplit(delimiter))
+ let splitted = string.rsplit(delimiter);
+ splitted
+ .take(occurrences)
+ .map(|s| s.len() + delimiter.len())
+ .sum::<usize>()
+ - delimiter.len()
};
- let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
- // The length of the substring covered by substr_index.
- let length = splitted
- .take(occurrences) // at least 1 element, since n != 0
- .map(|s| s.len() + delimiter.len())
- .sum::<usize>()
- - delimiter.len();
if n > 0 {
- Some(string[..length].to_owned())
+ match string.get(..length) {
+ Some(substring) => builder.append_value(substring),
+ None => builder.append_null(),
+ }
} else {
- Some(string[string.len() - length..].to_owned())
+ match string.get(string.len().saturating_sub(length)..) {
+ Some(substring) => builder.append_value(substring),
+ None => builder.append_null(),
+ }
}
}
- _ => None,
- })
- .collect::<GenericStringArray<T>>();
+ _ => builder.append_null(),
+ });
+
+ Ok(Arc::new(builder.finish()) as ArrayRef)
+}
- Ok(Arc::new(result) as ArrayRef)
+#[cfg(test)]
+mod tests {
+ use arrow::array::{Array, StringArray};
+ use arrow::datatypes::DataType::Utf8;
+
+ use datafusion_common::{Result, ScalarValue};
+ use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
+
+ use crate::unicode::substrindex::SubstrIndexFunc;
+ use crate::utils::test::test_function;
+
+ #[test]
+ fn test_functions() -> Result<()> {
+ test_function!(
+ SubstrIndexFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
+ ColumnarValue::Scalar(ScalarValue::from(".")),
+ ColumnarValue::Scalar(ScalarValue::from(1i64)),
+ ],
+ Ok(Some("www")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ SubstrIndexFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
+ ColumnarValue::Scalar(ScalarValue::from(".")),
+ ColumnarValue::Scalar(ScalarValue::from(2i64)),
+ ],
+ Ok(Some("www.apache")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ SubstrIndexFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
+ ColumnarValue::Scalar(ScalarValue::from(".")),
+ ColumnarValue::Scalar(ScalarValue::from(-2i64)),
+ ],
+ Ok(Some("apache.org")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ SubstrIndexFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
+ ColumnarValue::Scalar(ScalarValue::from(".")),
+ ColumnarValue::Scalar(ScalarValue::from(-1i64)),
+ ],
+ Ok(Some("org")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ SubstrIndexFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
+ ColumnarValue::Scalar(ScalarValue::from(".")),
+ ColumnarValue::Scalar(ScalarValue::from(0i64)),
+ ],
+ Ok(Some("")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ SubstrIndexFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("")),
+ ColumnarValue::Scalar(ScalarValue::from(".")),
+ ColumnarValue::Scalar(ScalarValue::from(1i64)),
+ ],
+ Ok(Some("")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ SubstrIndexFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
+ ColumnarValue::Scalar(ScalarValue::from("")),
+ ColumnarValue::Scalar(ScalarValue::from(1i64)),
+ ],
+ Ok(Some("")),
+ &str,
+ Utf8,
+ StringArray
+ );
+
+ Ok(())
+ }
}
diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt
index 21433ba168..38ebedf565 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -940,7 +940,8 @@ SELECT str, n, substring_index(str, '.', n) AS c FROM
(VALUES
ROW('arrow.apache.org'),
ROW('.'),
- ROW('...')
+ ROW('...'),
+ ROW(NULL)
) AS strings(str),
(VALUES
ROW(1),
@@ -954,6 +955,14 @@ SELECT str, n, substring_index(str, '.', n) AS c FROM
) AS occurrences(n)
ORDER BY str DESC, n;
----
+NULL -100 NULL
+NULL -3 NULL
+NULL -2 NULL
+NULL -1 NULL
+NULL 1 NULL
+NULL 2 NULL
+NULL 3 NULL
+NULL 100 NULL
arrow.apache.org -100 arrow.apache.org
arrow.apache.org -3 arrow.apache.org
arrow.apache.org -2 apache.org