This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 9bd183f417 functions: support strpos with mixed string types (#12072)
9bd183f417 is described below
commit 9bd183f4171b01bc72f869b92b55dca84d3dd3ae
Author: Nick Cameron <[email protected]>
AuthorDate: Wed Aug 21 06:55:36 2024 +1200
functions: support strpos with mixed string types (#12072)
Signed-off-by: Nick Cameron <[email protected]>
---
datafusion/functions/src/unicode/strpos.rs | 81 +++++++++++++++++++++++++-----
datafusion/functions/src/utils.rs | 2 +-
2 files changed, 70 insertions(+), 13 deletions(-)
diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs
index 395fd0b77d..702baf6e8f 100644
--- a/datafusion/functions/src/unicode/strpos.rs
+++ b/datafusion/functions/src/unicode/strpos.rs
@@ -78,10 +78,18 @@ impl ScalarUDFImpl for StrposFunc {
}
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
- match args[0].data_type() {
- DataType::Utf8 => make_scalar_function(strpos::<Int32Type>,
vec![])(args),
- DataType::LargeUtf8 => {
- make_scalar_function(strpos::<Int64Type>, vec![])(args)
+ match (args[0].data_type(), args[1].data_type()) {
+ (DataType::Utf8, DataType::Utf8) => {
+ make_scalar_function(strpos::<Int32Type, Int32Type>,
vec![])(args)
+ }
+ (DataType::Utf8, DataType::LargeUtf8) => {
+ make_scalar_function(strpos::<Int32Type, Int64Type>,
vec![])(args)
+ }
+ (DataType::LargeUtf8, DataType::Utf8) => {
+ make_scalar_function(strpos::<Int64Type, Int32Type>,
vec![])(args)
+ }
+ (DataType::LargeUtf8, DataType::LargeUtf8) => {
+ make_scalar_function(strpos::<Int64Type, Int64Type>,
vec![])(args)
}
other => exec_err!("Unsupported data type {other:?} for function
strpos"),
}
@@ -95,15 +103,18 @@ impl ScalarUDFImpl for StrposFunc {
/// Returns starting index of specified substring within string, or zero if
it's not present. (Same as position(substring in string), but note the reversed
argument order.)
/// strpos('high', 'ig') = 2
/// The implementation uses UTF-8 code points as characters
-fn strpos<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef>
+fn strpos<T0: ArrowPrimitiveType, T1: ArrowPrimitiveType>(
+ args: &[ArrayRef],
+) -> Result<ArrayRef>
where
- T::Native: OffsetSizeTrait,
+ T0::Native: OffsetSizeTrait,
+ T1::Native: OffsetSizeTrait,
{
- let string_array: &GenericStringArray<T::Native> =
- as_generic_string_array::<T::Native>(&args[0])?;
+ let string_array: &GenericStringArray<T0::Native> =
+ as_generic_string_array::<T0::Native>(&args[0])?;
- let substring_array: &GenericStringArray<T::Native> =
- as_generic_string_array::<T::Native>(&args[1])?;
+ let substring_array: &GenericStringArray<T1::Native> =
+ as_generic_string_array::<T1::Native>(&args[1])?;
let result = string_array
.iter()
@@ -112,7 +123,7 @@ where
(Some(string), Some(substring)) => {
// the find method returns the byte index of the substring
// Next, we count the number of the chars until that byte
- T::Native::from_usize(
+ T0::Native::from_usize(
string
.find(substring)
.map(|x| string[..x].chars().count() + 1)
@@ -121,7 +132,53 @@ where
}
_ => None,
})
- .collect::<PrimitiveArray<T>>();
+ .collect::<PrimitiveArray<T0>>();
Ok(Arc::new(result) as ArrayRef)
}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ use crate::utils::test::test_function;
+ use arrow::{
+ array::{Array as _, Int32Array, Int64Array},
+ datatypes::DataType::{Int32, Int64},
+ };
+ use datafusion_common::ScalarValue;
+
+ macro_rules! test_strpos {
+ ($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident
$t3:ident $t4:ident $t5:ident) => {
+ test_function!(
+ StrposFunc::new(),
+ &[
+
ColumnarValue::Scalar(ScalarValue::$t1(Some($lhs.to_owned()))),
+
ColumnarValue::Scalar(ScalarValue::$t2(Some($rhs.to_owned()))),
+ ],
+ Ok(Some($result)),
+ $t3,
+ $t4,
+ $t5
+ )
+ };
+ }
+
+ #[test]
+ fn strpos() {
+ test_strpos!("foo", "bar" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
+ test_strpos!("foobar", "foo" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
+ test_strpos!("foobar", "bar" -> 4; Utf8 Utf8 i32 Int32 Int32Array);
+
+ test_strpos!("foo", "bar" -> 0; LargeUtf8 LargeUtf8 i64 Int64
Int64Array);
+ test_strpos!("foobar", "foo" -> 1; LargeUtf8 LargeUtf8 i64 Int64
Int64Array);
+ test_strpos!("foobar", "bar" -> 4; LargeUtf8 LargeUtf8 i64 Int64
Int64Array);
+
+ test_strpos!("foo", "bar" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
+ test_strpos!("foobar", "foo" -> 1; Utf8 LargeUtf8 i32 Int32
Int32Array);
+ test_strpos!("foobar", "bar" -> 4; Utf8 LargeUtf8 i32 Int32
Int32Array);
+
+ test_strpos!("foo", "bar" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
+ test_strpos!("foobar", "foo" -> 1; LargeUtf8 Utf8 i64 Int64
Int64Array);
+ test_strpos!("foobar", "bar" -> 4; LargeUtf8 Utf8 i64 Int64
Int64Array);
+ }
+}
diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs
index 7b36717400..d36c5473ba 100644
--- a/datafusion/functions/src/utils.rs
+++ b/datafusion/functions/src/utils.rs
@@ -144,7 +144,7 @@ pub mod test {
assert_eq!(return_type.unwrap(), $EXPECTED_DATA_TYPE);
let result = func.invoke($ARGS);
- assert_eq!(result.is_ok(), true);
+ assert_eq!(result.is_ok(), true, "function returned an
error: {}", result.unwrap_err());
let len = $ARGS
.iter()
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@datafusion.apache.org
For additional commands, e-mail: commits-help@datafusion.apache.org