This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 9bd183f417 functions: support strpos with mixed string types (#12072)
9bd183f417 is described below

commit 9bd183f4171b01bc72f869b92b55dca84d3dd3ae
Author: Nick Cameron <[email protected]>
AuthorDate: Wed Aug 21 06:55:36 2024 +1200

    functions: support strpos with mixed string types (#12072)
    
    Signed-off-by: Nick Cameron <[email protected]>
---
 datafusion/functions/src/unicode/strpos.rs | 81 +++++++++++++++++++++++++-----
 datafusion/functions/src/utils.rs          |  2 +-
 2 files changed, 70 insertions(+), 13 deletions(-)

diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs
index 395fd0b77d..702baf6e8f 100644
--- a/datafusion/functions/src/unicode/strpos.rs
+++ b/datafusion/functions/src/unicode/strpos.rs
@@ -78,10 +78,18 @@ impl ScalarUDFImpl for StrposFunc {
     }
 
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
-        match args[0].data_type() {
-            DataType::Utf8 => make_scalar_function(strpos::<Int32Type>, vec![])(args),
-            DataType::LargeUtf8 => {
-                make_scalar_function(strpos::<Int64Type>, vec![])(args)
+        match (args[0].data_type(), args[1].data_type()) {
+            (DataType::Utf8, DataType::Utf8) => {
+                make_scalar_function(strpos::<Int32Type, Int32Type>, vec![])(args)
+            }
+            (DataType::Utf8, DataType::LargeUtf8) => {
+                make_scalar_function(strpos::<Int32Type, Int64Type>, vec![])(args)
+            }
+            (DataType::LargeUtf8, DataType::Utf8) => {
+                make_scalar_function(strpos::<Int64Type, Int32Type>, vec![])(args)
+            }
+            (DataType::LargeUtf8, DataType::LargeUtf8) => {
+                make_scalar_function(strpos::<Int64Type, Int64Type>, vec![])(args)
             }
             other => exec_err!("Unsupported data type {other:?} for function strpos"),
         }
@@ -95,15 +103,18 @@ impl ScalarUDFImpl for StrposFunc {
 /// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)
 /// strpos('high', 'ig') = 2
 /// The implementation uses UTF-8 code points as characters
-fn strpos<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef>
+fn strpos<T0: ArrowPrimitiveType, T1: ArrowPrimitiveType>(
+    args: &[ArrayRef],
+) -> Result<ArrayRef>
 where
-    T::Native: OffsetSizeTrait,
+    T0::Native: OffsetSizeTrait,
+    T1::Native: OffsetSizeTrait,
 {
-    let string_array: &GenericStringArray<T::Native> =
-        as_generic_string_array::<T::Native>(&args[0])?;
+    let string_array: &GenericStringArray<T0::Native> =
+        as_generic_string_array::<T0::Native>(&args[0])?;
 
-    let substring_array: &GenericStringArray<T::Native> =
-        as_generic_string_array::<T::Native>(&args[1])?;
+    let substring_array: &GenericStringArray<T1::Native> =
+        as_generic_string_array::<T1::Native>(&args[1])?;
 
     let result = string_array
         .iter()
@@ -112,7 +123,7 @@ where
             (Some(string), Some(substring)) => {
                 // the find method returns the byte index of the substring
                 // Next, we count the number of the chars until that byte
-                T::Native::from_usize(
+                T0::Native::from_usize(
                     string
                         .find(substring)
                         .map(|x| string[..x].chars().count() + 1)
@@ -121,7 +132,53 @@ where
             }
             _ => None,
         })
-        .collect::<PrimitiveArray<T>>();
+        .collect::<PrimitiveArray<T0>>();
 
     Ok(Arc::new(result) as ArrayRef)
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::utils::test::test_function;
+    use arrow::{
+        array::{Array as _, Int32Array, Int64Array},
+        datatypes::DataType::{Int32, Int64},
+    };
+    use datafusion_common::ScalarValue;
+
+    macro_rules! test_strpos {
+        ($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident $t3:ident $t4:ident $t5:ident) => {
+            test_function!(
+                StrposFunc::new(),
+                &[
+                    ColumnarValue::Scalar(ScalarValue::$t1(Some($lhs.to_owned()))),
+                    ColumnarValue::Scalar(ScalarValue::$t2(Some($rhs.to_owned()))),
+                ],
+                Ok(Some($result)),
+                $t3,
+                $t4,
+                $t5
+            )
+        };
+    }
+
+    #[test]
+    fn strpos() {
+        test_strpos!("foo", "bar" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
+        test_strpos!("foobar", "foo" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
+        test_strpos!("foobar", "bar" -> 4; Utf8 Utf8 i32 Int32 Int32Array);
+
+        test_strpos!("foo", "bar" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
+        test_strpos!("foobar", "foo" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
+        test_strpos!("foobar", "bar" -> 4; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
+
+        test_strpos!("foo", "bar" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
+        test_strpos!("foobar", "foo" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
+        test_strpos!("foobar", "bar" -> 4; Utf8 LargeUtf8 i32 Int32 Int32Array);
+
+        test_strpos!("foo", "bar" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
+        test_strpos!("foobar", "foo" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
+        test_strpos!("foobar", "bar" -> 4; LargeUtf8 Utf8 i64 Int64 Int64Array);
+    }
+}
diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs
index 7b36717400..d36c5473ba 100644
--- a/datafusion/functions/src/utils.rs
+++ b/datafusion/functions/src/utils.rs
@@ -144,7 +144,7 @@ pub mod test {
                     assert_eq!(return_type.unwrap(), $EXPECTED_DATA_TYPE);
 
                     let result = func.invoke($ARGS);
-                    assert_eq!(result.is_ok(), true);
+                    assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err());
 
                     let len = $ARGS
                         .iter()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to