This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 2e5c6b28fb Improve performance of `find_in_set` function (#14020)
2e5c6b28fb is described below

commit 2e5c6b28fbb6a7cb7d4351ebc005af24abdb6d54
Author: Tai Le Manh <[email protected]>
AuthorDate: Sun Jan 12 19:23:20 2025 +0700

    Improve performance of `find_in_set` function (#14020)
    
    * Improve performance of 'find_in_set' function
    
    Signed-off-by: Tai Le Manh <[email protected]>
    
    * Remove clippy warnings
    
    * Support scalar args for 'find_in_set' function
    
    Signed-off-by: Tai Le Manh <[email protected]>
    
    ---------
    
    Signed-off-by: Tai Le Manh <[email protected]>
---
 datafusion/functions/Cargo.toml                 |   5 +
 datafusion/functions/benches/find_in_set.rs     | 208 +++++++++++
 datafusion/functions/src/unicode/find_in_set.rs | 461 +++++++++++++++++++++---
 datafusion/functions/src/unicode/mod.rs         |   2 +-
 4 files changed, 631 insertions(+), 45 deletions(-)

diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml
index c8025fb2d8..db3e6838f6 100644
--- a/datafusion/functions/Cargo.toml
+++ b/datafusion/functions/Cargo.toml
@@ -218,3 +218,8 @@ required-features = ["math_expressions"]
 harness = false
 name = "initcap"
 required-features = ["unicode_expressions"]
+
+[[bench]]
+harness = false
+name = "find_in_set"
+required-features = ["unicode_expressions"]
diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs
new file mode 100644
index 0000000000..9307525482
--- /dev/null
+++ b/datafusion/functions/benches/find_in_set.rs
@@ -0,0 +1,208 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate criterion;
+
+use arrow::array::{StringArray, StringViewArray};
+use arrow::datatypes::DataType;
+use arrow::util::bench_util::{
+    create_string_array_with_len, create_string_view_array_with_len,
+};
+use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode};
+use datafusion_common::ScalarValue;
+use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
+use rand::distributions::Alphanumeric;
+use rand::prelude::StdRng;
+use rand::{Rng, SeedableRng};
+use std::sync::Arc;
+use std::time::Duration;
+
+/// gen_arr(4096, 128, 0.1, 0.1, true) will generate a StringViewArray with
+/// 4096 rows, each row containing a string with 128 random characters.
+/// around 10% of the rows are null, around 10% of the rows are non-ASCII.
+fn gen_args_array(
+    n_rows: usize,
+    str_len_chars: usize,
+    null_density: f32,
+    utf8_density: f32,
+    is_string_view: bool, // false -> StringArray, true -> StringViewArray
+) -> Vec<ColumnarValue> {
+    let mut rng = StdRng::seed_from_u64(42);
+    let rng_ref = &mut rng;
+
+    let num_elements = 5; // 5 elements separated by comma
+    let utf8 = "DataFusionДатаФусион数据融合📊🔥"; // includes utf8 encoding with 1~4 bytes
+    let corpus_char_count = utf8.chars().count();
+
+    let mut output_set_vec: Vec<Option<String>> = Vec::with_capacity(n_rows);
+    let mut output_element_vec: Vec<Option<String>> = Vec::with_capacity(n_rows);
+    for _ in 0..n_rows {
+        let rand_num = rng_ref.gen::<f32>(); // [0.0, 1.0)
+        if rand_num < null_density {
+            output_element_vec.push(None);
+            output_set_vec.push(None);
+        } else if rand_num < null_density + utf8_density {
+            // Generate random UTF-8 string with comma separators
+            let mut generated_string = String::with_capacity(str_len_chars);
+            for i in 0..num_elements {
+                for _ in 0..str_len_chars {
+                    let idx = rng_ref.gen_range(0..corpus_char_count);
+                    let char = utf8.chars().nth(idx).unwrap();
+                    generated_string.push(char);
+                }
+                if i < num_elements - 1 {
+                    generated_string.push(',');
+                }
+            }
+            output_element_vec.push(Some(random_element_in_set(&generated_string)));
+            output_set_vec.push(Some(generated_string));
+        } else {
+            // Generate random ASCII-only string with comma separators
+            let mut generated_string = String::with_capacity(str_len_chars);
+            for i in 0..num_elements {
+                for _ in 0..str_len_chars {
+                    let c = rng_ref.sample(Alphanumeric);
+                    generated_string.push(c as char);
+                }
+                if i < num_elements - 1 {
+                    generated_string.push(',');
+                }
+            }
+            output_element_vec.push(Some(random_element_in_set(&generated_string)));
+            output_set_vec.push(Some(generated_string));
+        }
+    }
+
+    if is_string_view {
+        let set_array: StringViewArray = output_set_vec.into_iter().collect();
+        let element_array: StringViewArray = output_element_vec.into_iter().collect();
+        vec![
+            ColumnarValue::Array(Arc::new(element_array)),
+            ColumnarValue::Array(Arc::new(set_array)),
+        ]
+    } else {
+        let set_array: StringArray = output_set_vec.clone().into_iter().collect();
+        let element_array: StringArray = output_element_vec.into_iter().collect();
+        vec![
+            ColumnarValue::Array(Arc::new(element_array)),
+            ColumnarValue::Array(Arc::new(set_array)),
+        ]
+    }
+}
+
+fn random_element_in_set(string: &str) -> String {
+    let elements: Vec<&str> = string.split(',').collect();
+
+    if elements.is_empty() || (elements.len() == 1 && elements[0].is_empty()) {
+        return String::new();
+    }
+
+    let mut rng = StdRng::seed_from_u64(44);
+    let random_index = rng.gen_range(0..elements.len());
+
+    elements[random_index].to_string()
+}
+
+fn gen_args_scalar(
+    n_rows: usize,
+    str_len_chars: usize,
+    null_density: f32,
+    is_string_view: bool, // false -> StringArray, true -> StringViewArray
+) -> Vec<ColumnarValue> {
+    let str_list = "Apache,DataFusion,SQL,Query,Engine".to_string();
+    if is_string_view {
+        let string =
+            create_string_view_array_with_len(n_rows, null_density, str_len_chars, false);
+        vec![
+            ColumnarValue::Array(Arc::new(string)),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(str_list))),
+        ]
+    } else {
+        let string =
+            create_string_array_with_len::<i32>(n_rows, null_density, str_len_chars);
+        vec![
+            ColumnarValue::Array(Arc::new(string)),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(str_list))),
+        ]
+    }
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    // All benches are single batch run with 8192 rows
+    let find_in_set = datafusion_functions::unicode::find_in_set();
+
+    let n_rows = 8192;
+    for str_len in [8, 32, 1024] {
+        let mut group = c.benchmark_group("find_in_set");
+        group.sampling_mode(SamplingMode::Flat);
+        group.sample_size(50);
+        group.measurement_time(Duration::from_secs(10));
+
+        let args = gen_args_array(n_rows, str_len, 0.1, 0.5, false);
+        group.bench_function(format!("string_len_{}", str_len), |b| {
+            b.iter(|| {
+                black_box(find_in_set.invoke_with_args(ScalarFunctionArgs {
+                    args: args.clone(),
+                    number_rows: n_rows,
+                    return_type: &DataType::Int32,
+                }))
+            })
+        });
+
+        let args = gen_args_array(n_rows, str_len, 0.1, 0.5, true);
+        group.bench_function(format!("string_view_len_{}", str_len), |b| {
+            b.iter(|| {
+                black_box(find_in_set.invoke_with_args(ScalarFunctionArgs {
+                    args: args.clone(),
+                    number_rows: n_rows,
+                    return_type: &DataType::Int32,
+                }))
+            })
+        });
+
+        group.finish();
+
+        let mut group = c.benchmark_group("find_in_set_scalar");
+
+        let args = gen_args_scalar(n_rows, str_len, 0.1, false);
+        group.bench_function(format!("string_len_{}", str_len), |b| {
+            b.iter(|| {
+                black_box(find_in_set.invoke_with_args(ScalarFunctionArgs {
+                    args: args.clone(),
+                    number_rows: n_rows,
+                    return_type: &DataType::Int32,
+                }))
+            })
+        });
+
+        let args = gen_args_scalar(n_rows, str_len, 0.1, true);
+        group.bench_function(format!("string_view_len_{}", str_len), |b| {
+            b.iter(|| {
+                black_box(find_in_set.invoke_with_args(ScalarFunctionArgs {
+                    args: args.clone(),
+                    number_rows: n_rows,
+                    return_type: &DataType::Int32,
+                }))
+            })
+        });
+
+        group.finish();
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs
index c4d9b51f60..12f213a827 100644
--- a/datafusion/functions/src/unicode/find_in_set.rs
+++ b/datafusion/functions/src/unicode/find_in_set.rs
@@ -19,16 +19,17 @@ use std::any::Any;
 use std::sync::Arc;
 
 use arrow::array::{
-    ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait,
-    PrimitiveArray,
+    new_null_array, ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray,
+    OffsetSizeTrait, PrimitiveArray,
 };
 use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
 
-use crate::utils::{make_scalar_function, utf8_to_int_type};
-use datafusion_common::{exec_err, Result};
+use crate::utils::utf8_to_int_type;
+use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
 use datafusion_expr::TypeSignature::Exact;
 use datafusion_expr::{
-    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
+    Volatility,
 };
 use datafusion_macros::user_doc;
 
@@ -42,7 +43,7 @@ use datafusion_macros::user_doc;
 | find_in_set(Utf8("b"),Utf8("a,b,c,d")) |
 +----------------------------------------+
 | 2                                      |
-+----------------------------------------+ 
++----------------------------------------+
 ```"#,
     argument(name = "str", description = "String expression to find in strlist."),
     argument(
@@ -94,12 +95,141 @@ impl ScalarUDFImpl for FindInSetFunc {
         utf8_to_int_type(&arg_types[0], "find_in_set")
     }
 
-    fn invoke_batch(
-        &self,
-        args: &[ColumnarValue],
-        _number_rows: usize,
-    ) -> Result<ColumnarValue> {
-        make_scalar_function(find_in_set, vec![])(args)
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let ScalarFunctionArgs { mut args, .. } = args;
+
+        if args.len() != 2 {
+            return exec_err!(
+                "find_in_set was called with {} arguments. It requires 2.",
+                args.len()
+            );
+        }
+
+        let str_list = args.pop().unwrap();
+        let string = args.pop().unwrap();
+
+        match (string, str_list) {
+            // both inputs are scalars
+            (
+                ColumnarValue::Scalar(
+                    ScalarValue::Utf8View(string)
+                    | ScalarValue::Utf8(string)
+                    | ScalarValue::LargeUtf8(string),
+                ),
+                ColumnarValue::Scalar(
+                    ScalarValue::Utf8View(str_list)
+                    | ScalarValue::Utf8(str_list)
+                    | ScalarValue::LargeUtf8(str_list),
+                ),
+            ) => {
+                let res = match (string, str_list) {
+                    (Some(string), Some(str_list)) => {
+                        let position = str_list
+                            .split(',')
+                            .position(|s| s == string)
+                            .map_or(0, |idx| idx + 1);
+
+                        Some(position as i32)
+                    }
+                    _ => None,
+                };
+                Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
+            }
+
+            // `string` is an array, `str_list` is scalar
+            (
+                ColumnarValue::Array(str_array),
+                ColumnarValue::Scalar(
+                    ScalarValue::Utf8View(str_list_literal)
+                    | ScalarValue::Utf8(str_list_literal)
+                    | ScalarValue::LargeUtf8(str_list_literal),
+                ),
+            ) => {
+                let result_array = match str_list_literal {
+                    // find_in_set(column_a, null) = null
+                    None => new_null_array(str_array.data_type(), str_array.len()),
+                    Some(str_list_literal) => {
+                        let str_list = str_list_literal.split(',').collect::<Vec<&str>>();
+                        let result = match str_array.data_type() {
+                            DataType::Utf8 => {
+                                let string_array = str_array.as_string::<i32>();
+                                find_in_set_right_literal::<Int32Type, _>(
+                                    string_array,
+                                    str_list,
+                                )
+                            }
+                            DataType::LargeUtf8 => {
+                                let string_array = str_array.as_string::<i64>();
+                                find_in_set_right_literal::<Int64Type, _>(
+                                    string_array,
+                                    str_list,
+                                )
+                            }
+                            DataType::Utf8View => {
+                                let string_array = str_array.as_string_view();
+                                find_in_set_right_literal::<Int32Type, _>(
+                                    string_array,
+                                    str_list,
+                                )
+                            }
+                            other => {
+                                exec_err!("Unsupported data type {other:?} for function find_in_set")
+                            }
+                        };
+                        Arc::new(result?)
+                    }
+                };
+                Ok(ColumnarValue::Array(result_array))
+            }
+
+            // `string` is scalar, `str_list` is an array
+            (
+                ColumnarValue::Scalar(
+                    ScalarValue::Utf8View(string_literal)
+                    | ScalarValue::Utf8(string_literal)
+                    | ScalarValue::LargeUtf8(string_literal),
+                ),
+                ColumnarValue::Array(str_list_array),
+            ) => {
+                let res = match string_literal {
+                    // find_in_set(null, column_b) = null
+                    None => {
+                        new_null_array(str_list_array.data_type(), str_list_array.len())
+                    }
+                    Some(string) => {
+                        let result = match str_list_array.data_type() {
+                            DataType::Utf8 => {
+                                let str_list = str_list_array.as_string::<i32>();
+                                find_in_set_left_literal::<Int32Type, _>(string, str_list)
+                            }
+                            DataType::LargeUtf8 => {
+                                let str_list = str_list_array.as_string::<i64>();
+                                find_in_set_left_literal::<Int64Type, _>(string, str_list)
+                            }
+                            DataType::Utf8View => {
+                                let str_list = str_list_array.as_string_view();
+                                find_in_set_left_literal::<Int32Type, _>(string, str_list)
+                            }
+                            other => {
+                                exec_err!("Unsupported data type {other:?} for function find_in_set")
+                            }
+                        };
+                        Arc::new(result?)
+                    }
+                };
+                Ok(ColumnarValue::Array(res))
+            }
+
+            // both inputs are arrays
+            (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => {
+                let res = find_in_set(base_array, exp_array)?;
+
+                Ok(ColumnarValue::Array(res))
+            }
+            _ => {
+                internal_err!("Invalid argument types for `find_in_set` function")
+            }
+        }
     }
 
     fn documentation(&self) -> Option<&Documentation> {
@@ -107,29 +237,24 @@ impl ScalarUDFImpl for FindInSetFunc {
     }
 }
 
-///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings
-///A string list is a string composed of substrings separated by , characters.
-fn find_in_set(args: &[ArrayRef]) -> Result<ArrayRef> {
-    if args.len() != 2 {
-        return exec_err!(
-            "find_in_set was called with {} arguments. It requires 2.",
-            args.len()
-        );
-    }
-    match args[0].data_type() {
+/// Returns a value in the range of 1 to N if the string `str` is in the string list `strlist`
+/// consisting of N substrings. A string list is a string composed of substrings separated by `,`
+/// characters.
+fn find_in_set(str: ArrayRef, str_list: ArrayRef) -> Result<ArrayRef> {
+    match str.data_type() {
         DataType::Utf8 => {
-            let string_array = args[0].as_string::<i32>();
-            let str_list_array = args[1].as_string::<i32>();
+            let string_array = str.as_string::<i32>();
+            let str_list_array = str_list.as_string::<i32>();
             find_in_set_general::<Int32Type, _>(string_array, str_list_array)
         }
         DataType::LargeUtf8 => {
-            let string_array = args[0].as_string::<i64>();
-            let str_list_array = args[1].as_string::<i64>();
+            let string_array = str.as_string::<i64>();
+            let str_list_array = str_list.as_string::<i64>();
             find_in_set_general::<Int64Type, _>(string_array, str_list_array)
         }
         DataType::Utf8View => {
-            let string_array = args[0].as_string_view();
-            let str_list_array = args[1].as_string_view();
+            let string_array = str.as_string_view();
+            let str_list_array = str_list.as_string_view();
             find_in_set_general::<Int32Type, _>(string_array, str_list_array)
         }
         other => {
@@ -138,31 +263,279 @@ fn find_in_set(args: &[ArrayRef]) -> Result<ArrayRef> {
     }
 }
 
-pub fn find_in_set_general<'a, T: ArrowPrimitiveType, V: ArrayAccessor<Item = &'a str>>(
+pub fn find_in_set_general<'a, T, V>(
     string_array: V,
     str_list_array: V,
 ) -> Result<ArrayRef>
 where
+    T: ArrowPrimitiveType,
     T::Native: OffsetSizeTrait,
+    V: ArrayAccessor<Item = &'a str>,
 {
     let string_iter = ArrayIter::new(string_array);
     let str_list_iter = ArrayIter::new(str_list_array);
-    let result = string_iter
+
+    let mut builder = PrimitiveArray::<T>::builder(string_iter.len());
+
+    string_iter
         .zip(str_list_iter)
-        .map(|(string, str_list)| match (string, str_list) {
-            (Some(string), Some(str_list)) => {
-                let mut res = 0;
-                let str_set: Vec<&str> = str_list.split(',').collect();
-                for (idx, str) in str_set.iter().enumerate() {
-                    if str == &string {
-                        res = idx + 1;
-                        break;
-                    }
+        .for_each(
+            |(string_opt, str_list_opt)| match (string_opt, str_list_opt) {
+                (Some(string), Some(str_list)) => {
+                    let position = str_list
+                        .split(',')
+                        .position(|s| s == string)
+                        .map_or(0, |idx| idx + 1);
+                    builder.append_value(T::Native::from_usize(position).unwrap());
                 }
-                T::Native::from_usize(res)
+                _ => builder.append_null(),
+            },
+        );
+
+    Ok(Arc::new(builder.finish()) as ArrayRef)
+}
+
+fn find_in_set_left_literal<'a, T, V>(
+    string: String,
+    str_list_array: V,
+) -> Result<ArrayRef>
+where
+    T: ArrowPrimitiveType,
+    T::Native: OffsetSizeTrait,
+    V: ArrayAccessor<Item = &'a str>,
+{
+    let mut builder = PrimitiveArray::<T>::builder(str_list_array.len());
+
+    let str_list_iter = ArrayIter::new(str_list_array);
+
+    str_list_iter.for_each(|str_list_opt| match str_list_opt {
+        Some(str_list) => {
+            let position = str_list
+                .split(',')
+                .position(|s| s == string)
+                .map_or(0, |idx| idx + 1);
+            builder.append_value(T::Native::from_usize(position).unwrap());
+        }
+        None => builder.append_null(),
+    });
+
+    Ok(Arc::new(builder.finish()) as ArrayRef)
+}
+
+fn find_in_set_right_literal<'a, T, V>(
+    string_array: V,
+    str_list: Vec<&str>,
+) -> Result<ArrayRef>
+where
+    T: ArrowPrimitiveType,
+    T::Native: OffsetSizeTrait,
+    V: ArrayAccessor<Item = &'a str>,
+{
+    let mut builder = PrimitiveArray::<T>::builder(string_array.len());
+
+    let string_iter = ArrayIter::new(string_array);
+
+    string_iter.for_each(|string_opt| match string_opt {
+        Some(string) => {
+            let position = str_list
+                .iter()
+                .position(|s| *s == string)
+                .map_or(0, |idx| idx + 1);
+            builder.append_value(T::Native::from_usize(position).unwrap());
+        }
+        None => builder.append_null(),
+    });
+
+    Ok(Arc::new(builder.finish()) as ArrayRef)
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::unicode::find_in_set::FindInSetFunc;
+    use crate::utils::test::test_function;
+    use arrow::array::{Array, Int32Array, StringArray};
+    use arrow::datatypes::DataType::Int32;
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
+    use std::sync::Arc;
+
+    #[test]
+    fn test_functions() -> Result<()> {
+        test_function!(
+            FindInSetFunc::new(),
+            vec![
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
+            ],
+            Ok(Some(1)),
+            i32,
+            Int32,
+            Int32Array
+        );
+        test_function!(
+            FindInSetFunc::new(),
+            vec![
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("🔥")))),
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+                    "a,Д,🔥"
+                )))),
+            ],
+            Ok(Some(3)),
+            i32,
+            Int32,
+            Int32Array
+        );
+        test_function!(
+            FindInSetFunc::new(),
+            vec![
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("d")))),
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
+            ],
+            Ok(Some(0)),
+            i32,
+            Int32,
+            Int32Array
+        );
+        test_function!(
+            FindInSetFunc::new(),
+            vec![
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+                    "Apache Software Foundation"
+                )))),
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+                    "Github,Apache Software Foundation,DataFusion"
+                )))),
+            ],
+            Ok(Some(2)),
+            i32,
+            Int32,
+            Int32Array
+        );
+        test_function!(
+            FindInSetFunc::new(),
+            vec![
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
+            ],
+            Ok(Some(0)),
+            i32,
+            Int32,
+            Int32Array
+        );
+        test_function!(
+            FindInSetFunc::new(),
+            vec![
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
+            ],
+            Ok(Some(0)),
+            i32,
+            Int32,
+            Int32Array
+        );
+        test_function!(
+            FindInSetFunc::new(),
+            vec![
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a")))),
+                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
+            ],
+            Ok(None),
+            i32,
+            Int32,
+            Int32Array
+        );
+        test_function!(
+            FindInSetFunc::new(),
+            vec![
+                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
+                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
+            ],
+            Ok(None),
+            i32,
+            Int32,
+            Int32Array
+        );
+
+        Ok(())
+    }
+
+    macro_rules! test_find_in_set {
+        ($test_name:ident, $args:expr, $expected:expr) => {
+            #[test]
+            fn $test_name() -> Result<()> {
+                let fis = crate::unicode::find_in_set();
+
+                let args = $args;
+                let expected = $expected;
+
+                let type_array = args.iter().map(|a| a.data_type()).collect::<Vec<_>>();
+                let cardinality = args
+                    .iter()
+                    .fold(Option::<usize>::None, |acc, arg| match arg {
+                        ColumnarValue::Scalar(_) => acc,
+                        ColumnarValue::Array(a) => Some(a.len()),
+                    })
+                    .unwrap_or(1);
+                let return_type = fis.return_type(&type_array)?;
+                let result = fis.invoke_with_args(ScalarFunctionArgs {
+                    args,
+                    number_rows: cardinality,
+                    return_type: &return_type,
+                });
+                assert!(result.is_ok());
+
+                let result = result?
+                    .to_array(cardinality)
+                    .expect("Failed to convert to array");
+                let result = result
+                    .as_any()
+                    .downcast_ref::<Int32Array>()
+                    .expect("Failed to convert to type");
+                assert_eq!(*result, expected);
+
+                Ok(())
             }
-            _ => None,
-        })
-        .collect::<PrimitiveArray<T>>();
-    Ok(Arc::new(result) as ArrayRef)
+        };
+    }
+
+    test_find_in_set!(
+        test_find_in_set_with_scalar_args,
+        vec![
+            ColumnarValue::Array(Arc::new(StringArray::from(vec![
+                "", "a", "b", "c", "d"
+            ]))),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("b,c,d".to_string()))),
+        ],
+        Int32Array::from(vec![0, 0, 1, 2, 3])
+    );
+    test_find_in_set!(
+        test_find_in_set_with_scalar_args_2,
+        vec![
+            ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
+                "ApacheSoftware".to_string()
+            ))),
+            ColumnarValue::Array(Arc::new(StringArray::from(vec![
+                "a,b,c",
+                "ApacheSoftware,Github,DataFusion",
+                ""
+            ]))),
+        ],
+        Int32Array::from(vec![0, 1, 0])
+    );
+    test_find_in_set!(
+        test_find_in_set_with_scalar_args_3,
+        vec![
+            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
+            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a,b,c".to_string()))),
+        ],
+        Int32Array::from(vec![None::<i32>; 3])
+    );
+    test_find_in_set!(
+        test_find_in_set_with_scalar_args_4,
+        vec![
+            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a".to_string()))),
+            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
+        ],
+        Int32Array::from(vec![None::<i32>; 3])
+    );
 }
diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs
index e8e3eb3f4e..3c5cde3789 100644
--- a/datafusion/functions/src/unicode/mod.rs
+++ b/datafusion/functions/src/unicode/mod.rs
@@ -102,7 +102,7 @@ pub mod expr_fn {
         string
     ),(
         find_in_set,
-        "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings",
+        "Returns a value in the range of 1 to N if the string `str` is in the string list `strlist` consisting of N substrings",
         string strlist
     ));
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to