This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 2e5c6b28fb Improve performance of `find_in_set` function (#14020)
2e5c6b28fb is described below
commit 2e5c6b28fbb6a7cb7d4351ebc005af24abdb6d54
Author: Tai Le Manh <[email protected]>
AuthorDate: Sun Jan 12 19:23:20 2025 +0700
Improve performance of `find_in_set` function (#14020)
* Improve performance of 'find_in_set' function
Signed-off-by: Tai Le Manh <[email protected]>
* Remove clippy warnings
* Support scalar args for 'find_in_set' function
Signed-off-by: Tai Le Manh <[email protected]>
---------
Signed-off-by: Tai Le Manh <[email protected]>
---
datafusion/functions/Cargo.toml | 5 +
datafusion/functions/benches/find_in_set.rs | 208 +++++++++++
datafusion/functions/src/unicode/find_in_set.rs | 461 +++++++++++++++++++++---
datafusion/functions/src/unicode/mod.rs | 2 +-
4 files changed, 631 insertions(+), 45 deletions(-)
diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml
index c8025fb2d8..db3e6838f6 100644
--- a/datafusion/functions/Cargo.toml
+++ b/datafusion/functions/Cargo.toml
@@ -218,3 +218,8 @@ required-features = ["math_expressions"]
harness = false
name = "initcap"
required-features = ["unicode_expressions"]
+
+[[bench]]
+harness = false
+name = "find_in_set"
+required-features = ["unicode_expressions"]
diff --git a/datafusion/functions/benches/find_in_set.rs
b/datafusion/functions/benches/find_in_set.rs
new file mode 100644
index 0000000000..9307525482
--- /dev/null
+++ b/datafusion/functions/benches/find_in_set.rs
@@ -0,0 +1,208 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate criterion;
+
+use arrow::array::{StringArray, StringViewArray};
+use arrow::datatypes::DataType;
+use arrow::util::bench_util::{
+ create_string_array_with_len, create_string_view_array_with_len,
+};
+use criterion::{black_box, criterion_group, criterion_main, Criterion,
SamplingMode};
+use datafusion_common::ScalarValue;
+use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
+use rand::distributions::Alphanumeric;
+use rand::prelude::StdRng;
+use rand::{Rng, SeedableRng};
+use std::sync::Arc;
+use std::time::Duration;
+
+/// Generates the `(element, set)` argument pair for the array/array benchmark.
+///
+/// For example, `gen_args_array(4096, 128, 0.1, 0.1, true)` produces two
+/// `StringViewArray`s of 4096 rows. Every non-null `set` row holds 5
+/// comma-separated elements of 128 random characters each, and the matching
+/// `element` row is one element of that set, chosen by `random_element_in_set`.
+/// Around 10% of the rows are null (in both arrays), around 10% are drawn from a
+/// corpus containing non-ASCII characters, and the rest are ASCII alphanumeric.
+///
+/// The generator is seeded with a constant, so every call with the same
+/// arguments yields the same data.
+fn gen_args_array(
+    n_rows: usize,
+    str_len_chars: usize,
+    null_density: f32,
+    utf8_density: f32,
+    is_string_view: bool, // false -> StringArray, true -> StringViewArray
+) -> Vec<ColumnarValue> {
+    let mut rng = StdRng::seed_from_u64(42);
+
+    let num_elements = 5; // 5 elements separated by comma
+    // Includes utf8 encodings with 1~4 bytes. Collected once so that picking a
+    // random character is an index lookup instead of a walk over the string.
+    let corpus: Vec<char> = "DataFusionДатаФусион数据融合📊🔥".chars().collect();
+
+    let mut output_set_vec: Vec<Option<String>> = Vec::with_capacity(n_rows);
+    let mut output_element_vec: Vec<Option<String>> = Vec::with_capacity(n_rows);
+    for _ in 0..n_rows {
+        let rand_num = rng.gen::<f32>(); // [0.0, 1.0)
+        if rand_num < null_density {
+            output_element_vec.push(None);
+            output_set_vec.push(None);
+            continue;
+        }
+
+        // Rows below this threshold use the UTF-8 corpus, the rest are ASCII.
+        let use_utf8 = rand_num < null_density + utf8_density;
+
+        // Exact byte size for ASCII rows; a lower bound for multi-byte rows.
+        let mut generated_string =
+            String::with_capacity(num_elements * (str_len_chars + 1));
+        for i in 0..num_elements {
+            for _ in 0..str_len_chars {
+                let c = if use_utf8 {
+                    corpus[rng.gen_range(0..corpus.len())]
+                } else {
+                    rng.sample(Alphanumeric) as char
+                };
+                generated_string.push(c);
+            }
+            if i < num_elements - 1 {
+                generated_string.push(',');
+            }
+        }
+
+        output_element_vec.push(Some(random_element_in_set(&generated_string)));
+        output_set_vec.push(Some(generated_string));
+    }
+
+    // `element` is the first argument of find_in_set, `set` the second.
+    if is_string_view {
+        let set_array: StringViewArray = output_set_vec.into_iter().collect();
+        let element_array: StringViewArray = output_element_vec.into_iter().collect();
+        vec![
+            ColumnarValue::Array(Arc::new(element_array)),
+            ColumnarValue::Array(Arc::new(set_array)),
+        ]
+    } else {
+        let set_array: StringArray = output_set_vec.into_iter().collect();
+        let element_array: StringArray = output_element_vec.into_iter().collect();
+        vec![
+            ColumnarValue::Array(Arc::new(element_array)),
+            ColumnarValue::Array(Arc::new(set_array)),
+        ]
+    }
+}
+
+/// Returns one of the comma-separated elements of `string`.
+///
+/// The generator is re-created from the same constant seed on every call, so
+/// the choice is deterministic for a given number of elements. An empty
+/// `string` yields an empty result.
+fn random_element_in_set(string: &str) -> String {
+    let parts = string.split(',').collect::<Vec<&str>>();
+
+    match parts.as_slice() {
+        // Nothing to choose from.
+        [] | [""] => String::new(),
+        candidates => {
+            let mut rng = StdRng::seed_from_u64(44);
+            let chosen = rng.gen_range(0..candidates.len());
+            candidates[chosen].to_string()
+        }
+    }
+}
+
+/// Generates the arguments for the array/scalar benchmark.
+///
+/// The first argument is a string array of `n_rows` rows built by arrow's
+/// bench utilities from `null_density` and `str_len_chars`; the second is a
+/// fixed `Utf8` scalar holding the comma-separated list to search in.
+fn gen_args_scalar(
+    n_rows: usize,
+    str_len_chars: usize,
+    null_density: f32,
+    is_string_view: bool, // false -> StringArray, true -> StringViewArray
+) -> Vec<ColumnarValue> {
+    let haystack = ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
+        "Apache,DataFusion,SQL,Query,Engine",
+    ))));
+
+    let needles = if is_string_view {
+        ColumnarValue::Array(Arc::new(create_string_view_array_with_len(
+            n_rows,
+            null_density,
+            str_len_chars,
+            false,
+        )))
+    } else {
+        ColumnarValue::Array(Arc::new(create_string_array_with_len::<i32>(
+            n_rows,
+            null_density,
+            str_len_chars,
+        )))
+    };
+
+    vec![needles, haystack]
+}
+
+/// Benchmarks `find_in_set` for several element lengths, with array/array
+/// arguments (group `find_in_set`) and array/scalar arguments (group
+/// `find_in_set_scalar`), each for `StringArray` and `StringViewArray` input.
+fn criterion_benchmark(c: &mut Criterion) {
+    let udf = datafusion_functions::unicode::find_in_set();
+
+    // All benches are single batch run with 8192 rows
+    let n_rows = 8192;
+
+    // One invocation of the function over a prepared batch of arguments.
+    let run = |args: &[ColumnarValue]| {
+        udf.invoke_with_args(ScalarFunctionArgs {
+            args: args.to_vec(),
+            number_rows: n_rows,
+            return_type: &DataType::Int32,
+        })
+    };
+
+    for str_len in [8, 32, 1024] {
+        let plain_id = format!("string_len_{}", str_len);
+        let view_id = format!("string_view_len_{}", str_len);
+
+        // Both arguments are arrays.
+        let mut array_group = c.benchmark_group("find_in_set");
+        array_group.sampling_mode(SamplingMode::Flat);
+        array_group.sample_size(50);
+        array_group.measurement_time(Duration::from_secs(10));
+
+        let args = gen_args_array(n_rows, str_len, 0.1, 0.5, false);
+        array_group
+            .bench_function(plain_id.as_str(), |b| b.iter(|| black_box(run(&args))));
+
+        let args = gen_args_array(n_rows, str_len, 0.1, 0.5, true);
+        array_group
+            .bench_function(view_id.as_str(), |b| b.iter(|| black_box(run(&args))));
+
+        array_group.finish();
+
+        // First argument is an array, second a scalar list.
+        let mut scalar_group = c.benchmark_group("find_in_set_scalar");
+
+        let args = gen_args_scalar(n_rows, str_len, 0.1, false);
+        scalar_group
+            .bench_function(plain_id.as_str(), |b| b.iter(|| black_box(run(&args))));
+
+        let args = gen_args_scalar(n_rows, str_len, 0.1, true);
+        scalar_group
+            .bench_function(view_id.as_str(), |b| b.iter(|| black_box(run(&args))));
+
+        scalar_group.finish();
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/functions/src/unicode/find_in_set.rs
b/datafusion/functions/src/unicode/find_in_set.rs
index c4d9b51f60..12f213a827 100644
--- a/datafusion/functions/src/unicode/find_in_set.rs
+++ b/datafusion/functions/src/unicode/find_in_set.rs
@@ -19,16 +19,17 @@ use std::any::Any;
use std::sync::Arc;
use arrow::array::{
- ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray,
OffsetSizeTrait,
- PrimitiveArray,
+ new_null_array, ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType,
AsArray,
+ OffsetSizeTrait, PrimitiveArray,
};
use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
-use crate::utils::{make_scalar_function, utf8_to_int_type};
-use datafusion_common::{exec_err, Result};
+use crate::utils::utf8_to_int_type;
+use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{
- ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
+ Volatility,
};
use datafusion_macros::user_doc;
@@ -42,7 +43,7 @@ use datafusion_macros::user_doc;
| find_in_set(Utf8("b"),Utf8("a,b,c,d")) |
+----------------------------------------+
| 2 |
-+----------------------------------------+
++----------------------------------------+
```"#,
argument(name = "str", description = "String expression to find in
strlist."),
argument(
@@ -94,12 +95,141 @@ impl ScalarUDFImpl for FindInSetFunc {
utf8_to_int_type(&arg_types[0], "find_in_set")
}
- fn invoke_batch(
- &self,
- args: &[ColumnarValue],
- _number_rows: usize,
- ) -> Result<ColumnarValue> {
- make_scalar_function(find_in_set, vec![])(args)
+ /// Evaluates `find_in_set(str, strlist)`.
+ ///
+ /// Each row yields the 1-based position of `str` within the comma-separated
+ /// `strlist`, `0` when it is not present, and null when either input is null.
+ /// Every scalar/array combination of the two arguments has its own path, so a
+ /// literal `strlist` is split only once per batch.
+ ///
+ /// The result is `Int64` for `LargeUtf8` input and `Int32` for `Utf8` and
+ /// `Utf8View` input on every path, including the all-null results produced
+ /// for a null literal.
+ ///
+ /// # Errors
+ /// Returns an execution error when the number of arguments is not 2 or when
+ /// an array on a scalar/array path is not a string array, and an internal
+ /// error when the arguments are not string scalars or arrays.
+ fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+     let ScalarFunctionArgs { mut args, .. } = args;
+
+     if args.len() != 2 {
+         return exec_err!(
+             "find_in_set was called with {} arguments. It requires 2.",
+             args.len()
+         );
+     }
+
+     let str_list = args.pop().unwrap();
+     let string = args.pop().unwrap();
+
+     // The array paths emit Int64 for LargeUtf8; remember the input type so
+     // the scalar path can agree with them after the variant is destructured.
+     let returns_int64 = matches!(string.data_type(), DataType::LargeUtf8);
+
+     match (string, str_list) {
+         // both inputs are scalars
+         (
+             ColumnarValue::Scalar(
+                 ScalarValue::Utf8View(string)
+                 | ScalarValue::Utf8(string)
+                 | ScalarValue::LargeUtf8(string),
+             ),
+             ColumnarValue::Scalar(
+                 ScalarValue::Utf8View(str_list)
+                 | ScalarValue::Utf8(str_list)
+                 | ScalarValue::LargeUtf8(str_list),
+             ),
+         ) => {
+             let position = match (string, str_list) {
+                 (Some(string), Some(str_list)) => Some(
+                     str_list
+                         .split(',')
+                         .position(|s| s == string)
+                         .map_or(0, |idx| idx + 1),
+                 ),
+                 _ => None,
+             };
+             let res = if returns_int64 {
+                 ScalarValue::Int64(position.map(|p| p as i64))
+             } else {
+                 ScalarValue::Int32(position.map(|p| p as i32))
+             };
+             Ok(ColumnarValue::Scalar(res))
+         }
+
+         // `string` is an array, `str_list` is scalar
+         (
+             ColumnarValue::Array(str_array),
+             ColumnarValue::Scalar(
+                 ScalarValue::Utf8View(str_list_literal)
+                 | ScalarValue::Utf8(str_list_literal)
+                 | ScalarValue::LargeUtf8(str_list_literal),
+             ),
+         ) => {
+             let result_array = match str_list_literal {
+                 // find_in_set(column_a, null) = null, typed as the integer
+                 // result type rather than the input string type
+                 None => {
+                     let int_type =
+                         utf8_to_int_type(str_array.data_type(), "find_in_set")?;
+                     new_null_array(&int_type, str_array.len())
+                 }
+                 Some(str_list_literal) => {
+                     // Split the literal once for the whole batch.
+                     let str_list =
+                         str_list_literal.split(',').collect::<Vec<&str>>();
+                     let result = match str_array.data_type() {
+                         DataType::Utf8 => {
+                             let string_array = str_array.as_string::<i32>();
+                             find_in_set_right_literal::<Int32Type, _>(
+                                 string_array,
+                                 str_list,
+                             )
+                         }
+                         DataType::LargeUtf8 => {
+                             let string_array = str_array.as_string::<i64>();
+                             find_in_set_right_literal::<Int64Type, _>(
+                                 string_array,
+                                 str_list,
+                             )
+                         }
+                         DataType::Utf8View => {
+                             let string_array = str_array.as_string_view();
+                             find_in_set_right_literal::<Int32Type, _>(
+                                 string_array,
+                                 str_list,
+                             )
+                         }
+                         other => {
+                             exec_err!("Unsupported data type {other:?} for function find_in_set")
+                         }
+                     };
+                     // Already an `ArrayRef`; no further `Arc` wrapping needed.
+                     result?
+                 }
+             };
+             Ok(ColumnarValue::Array(result_array))
+         }
+
+         // `string` is scalar, `str_list` is an array
+         (
+             ColumnarValue::Scalar(
+                 ScalarValue::Utf8View(string_literal)
+                 | ScalarValue::Utf8(string_literal)
+                 | ScalarValue::LargeUtf8(string_literal),
+             ),
+             ColumnarValue::Array(str_list_array),
+         ) => {
+             let res = match string_literal {
+                 // find_in_set(null, column_b) = null, typed the same way as
+                 // the non-null results below (derived from the array type)
+                 None => {
+                     let int_type =
+                         utf8_to_int_type(str_list_array.data_type(), "find_in_set")?;
+                     new_null_array(&int_type, str_list_array.len())
+                 }
+                 Some(string) => {
+                     let result = match str_list_array.data_type() {
+                         DataType::Utf8 => {
+                             let str_list = str_list_array.as_string::<i32>();
+                             find_in_set_left_literal::<Int32Type, _>(string, str_list)
+                         }
+                         DataType::LargeUtf8 => {
+                             let str_list = str_list_array.as_string::<i64>();
+                             find_in_set_left_literal::<Int64Type, _>(string, str_list)
+                         }
+                         DataType::Utf8View => {
+                             let str_list = str_list_array.as_string_view();
+                             find_in_set_left_literal::<Int32Type, _>(string, str_list)
+                         }
+                         other => {
+                             exec_err!("Unsupported data type {other:?} for function find_in_set")
+                         }
+                     };
+                     // Already an `ArrayRef`; no further `Arc` wrapping needed.
+                     result?
+                 }
+             };
+             Ok(ColumnarValue::Array(res))
+         }
+
+         // both inputs are arrays
+         (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => {
+             let res = find_in_set(base_array, exp_array)?;
+
+             Ok(ColumnarValue::Array(res))
+         }
+         _ => {
+             internal_err!("Invalid argument types for `find_in_set` function")
+         }
+     }
+ }
fn documentation(&self) -> Option<&Documentation> {
@@ -107,29 +237,24 @@ impl ScalarUDFImpl for FindInSetFunc {
}
}
-///Returns a value in the range of 1 to N if the string str is in the string
list strlist consisting of N substrings
-///A string list is a string composed of substrings separated by , characters.
-fn find_in_set(args: &[ArrayRef]) -> Result<ArrayRef> {
- if args.len() != 2 {
- return exec_err!(
- "find_in_set was called with {} arguments. It requires 2.",
- args.len()
- );
- }
- match args[0].data_type() {
+/// Returns a value in the range of 1 to N if the string `str` is in the
string list `strlist`
+/// consisting of N substrings. A string list is a string composed of
substrings separated by `,`
+/// characters.
+fn find_in_set(str: ArrayRef, str_list: ArrayRef) -> Result<ArrayRef> {
+ match str.data_type() {
DataType::Utf8 => {
- let string_array = args[0].as_string::<i32>();
- let str_list_array = args[1].as_string::<i32>();
+ let string_array = str.as_string::<i32>();
+ let str_list_array = str_list.as_string::<i32>();
find_in_set_general::<Int32Type, _>(string_array, str_list_array)
}
DataType::LargeUtf8 => {
- let string_array = args[0].as_string::<i64>();
- let str_list_array = args[1].as_string::<i64>();
+ let string_array = str.as_string::<i64>();
+ let str_list_array = str_list.as_string::<i64>();
find_in_set_general::<Int64Type, _>(string_array, str_list_array)
}
DataType::Utf8View => {
- let string_array = args[0].as_string_view();
- let str_list_array = args[1].as_string_view();
+ let string_array = str.as_string_view();
+ let str_list_array = str_list.as_string_view();
find_in_set_general::<Int32Type, _>(string_array, str_list_array)
}
other => {
@@ -138,31 +263,279 @@ fn find_in_set(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}
-pub fn find_in_set_general<'a, T: ArrowPrimitiveType, V: ArrayAccessor<Item =
&'a str>>(
+pub fn find_in_set_general<'a, T, V>(
string_array: V,
str_list_array: V,
) -> Result<ArrayRef>
where
+ T: ArrowPrimitiveType,
T::Native: OffsetSizeTrait,
+ V: ArrayAccessor<Item = &'a str>,
{
let string_iter = ArrayIter::new(string_array);
let str_list_iter = ArrayIter::new(str_list_array);
- let result = string_iter
+
+ let mut builder = PrimitiveArray::<T>::builder(string_iter.len());
+
+ string_iter
.zip(str_list_iter)
- .map(|(string, str_list)| match (string, str_list) {
- (Some(string), Some(str_list)) => {
- let mut res = 0;
- let str_set: Vec<&str> = str_list.split(',').collect();
- for (idx, str) in str_set.iter().enumerate() {
- if str == &string {
- res = idx + 1;
- break;
- }
+ .for_each(
+ |(string_opt, str_list_opt)| match (string_opt, str_list_opt) {
+ (Some(string), Some(str_list)) => {
+ let position = str_list
+ .split(',')
+ .position(|s| s == string)
+ .map_or(0, |idx| idx + 1);
+
builder.append_value(T::Native::from_usize(position).unwrap());
}
- T::Native::from_usize(res)
+ _ => builder.append_null(),
+ },
+ );
+
+ Ok(Arc::new(builder.finish()) as ArrayRef)
+}
+
+/// Looks up the literal `string` in every comma-separated list of
+/// `str_list_array`.
+///
+/// Each non-null row yields the 1-based position of `string` in that row's
+/// list, or `0` when it is absent; null rows stay null.
+fn find_in_set_left_literal<'a, T, V>(
+    string: String,
+    str_list_array: V,
+) -> Result<ArrayRef>
+where
+    T: ArrowPrimitiveType,
+    T::Native: OffsetSizeTrait,
+    V: ArrayAccessor<Item = &'a str>,
+{
+    let mut positions = PrimitiveArray::<T>::builder(str_list_array.len());
+
+    for row in ArrayIter::new(str_list_array) {
+        match row {
+            None => positions.append_null(),
+            Some(candidates) => {
+                let found = match candidates.split(',').position(|s| s == string) {
+                    Some(zero_based) => zero_based + 1,
+                    None => 0,
+                };
+                positions.append_value(T::Native::from_usize(found).unwrap());
+            }
+        }
+    }
+
+    let finished: ArrayRef = Arc::new(positions.finish());
+    Ok(finished)
+}
+
+/// Looks up every value of `string_array` in the already-split literal list
+/// `str_list`.
+///
+/// Each non-null row yields the 1-based position of its value in `str_list`,
+/// or `0` when it is absent; null rows stay null.
+fn find_in_set_right_literal<'a, T, V>(
+    string_array: V,
+    str_list: Vec<&str>,
+) -> Result<ArrayRef>
+where
+    T: ArrowPrimitiveType,
+    T::Native: OffsetSizeTrait,
+    V: ArrayAccessor<Item = &'a str>,
+{
+    let mut positions = PrimitiveArray::<T>::builder(string_array.len());
+
+    for row in ArrayIter::new(string_array) {
+        match row {
+            None => positions.append_null(),
+            Some(needle) => {
+                let found = match str_list.iter().position(|s| *s == needle) {
+                    Some(zero_based) => zero_based + 1,
+                    None => 0,
+                };
+                positions.append_value(T::Native::from_usize(found).unwrap());
+            }
+        }
+    }
+
+    let finished: ArrayRef = Arc::new(positions.finish());
+    Ok(finished)
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::unicode::find_in_set::FindInSetFunc;
+ use crate::utils::test::test_function;
+ use arrow::array::{Array, Int32Array, StringArray};
+ use arrow::datatypes::DataType::Int32;
+ use datafusion_common::{Result, ScalarValue};
+ use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
+ use std::sync::Arc;
+
+ #[test]
+ fn test_functions() -> Result<()> {
+ test_function!(
+ FindInSetFunc::new(),
+ vec![
+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
+ ],
+ Ok(Some(1)),
+ i32,
+ Int32,
+ Int32Array
+ );
+ test_function!(
+ FindInSetFunc::new(),
+ vec![
+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("🔥")))),
+ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+ "a,Д,🔥"
+ )))),
+ ],
+ Ok(Some(3)),
+ i32,
+ Int32,
+ Int32Array
+ );
+ test_function!(
+ FindInSetFunc::new(),
+ vec![
+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("d")))),
+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
+ ],
+ Ok(Some(0)),
+ i32,
+ Int32,
+ Int32Array
+ );
+ test_function!(
+ FindInSetFunc::new(),
+ vec![
+ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+ "Apache Software Foundation"
+ )))),
+ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+ "Github,Apache Software Foundation,DataFusion"
+ )))),
+ ],
+ Ok(Some(2)),
+ i32,
+ Int32,
+ Int32Array
+ );
+ test_function!(
+ FindInSetFunc::new(),
+ vec![
+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
+ ],
+ Ok(Some(0)),
+ i32,
+ Int32,
+ Int32Array
+ );
+ test_function!(
+ FindInSetFunc::new(),
+ vec![
+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
+ ],
+ Ok(Some(0)),
+ i32,
+ Int32,
+ Int32Array
+ );
+ test_function!(
+ FindInSetFunc::new(),
+ vec![
+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a")))),
+ ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
+ ],
+ Ok(None),
+ i32,
+ Int32,
+ Int32Array
+ );
+ test_function!(
+ FindInSetFunc::new(),
+ vec![
+ ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
+ ],
+ Ok(None),
+ i32,
+ Int32,
+ Int32Array
+ );
+
+ Ok(())
+ }
+
+ macro_rules! test_find_in_set {
+ ($test_name:ident, $args:expr, $expected:expr) => {
+ #[test]
+ fn $test_name() -> Result<()> {
+ let fis = crate::unicode::find_in_set();
+
+ let args = $args;
+ let expected = $expected;
+
+ let type_array = args.iter().map(|a|
a.data_type()).collect::<Vec<_>>();
+ let cardinality = args
+ .iter()
+ .fold(Option::<usize>::None, |acc, arg| match arg {
+ ColumnarValue::Scalar(_) => acc,
+ ColumnarValue::Array(a) => Some(a.len()),
+ })
+ .unwrap_or(1);
+ let return_type = fis.return_type(&type_array)?;
+ let result = fis.invoke_with_args(ScalarFunctionArgs {
+ args,
+ number_rows: cardinality,
+ return_type: &return_type,
+ });
+ assert!(result.is_ok());
+
+ let result = result?
+ .to_array(cardinality)
+ .expect("Failed to convert to array");
+ let result = result
+ .as_any()
+ .downcast_ref::<Int32Array>()
+ .expect("Failed to convert to type");
+ assert_eq!(*result, expected);
+
+ Ok(())
}
- _ => None,
- })
- .collect::<PrimitiveArray<T>>();
- Ok(Arc::new(result) as ArrayRef)
+ };
+ }
+
+ test_find_in_set!(
+ test_find_in_set_with_scalar_args,
+ vec![
+ ColumnarValue::Array(Arc::new(StringArray::from(vec![
+ "", "a", "b", "c", "d"
+ ]))),
+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("b,c,d".to_string()))),
+ ],
+ Int32Array::from(vec![0, 0, 1, 2, 3])
+ );
+ test_find_in_set!(
+ test_find_in_set_with_scalar_args_2,
+ vec![
+ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
+ "ApacheSoftware".to_string()
+ ))),
+ ColumnarValue::Array(Arc::new(StringArray::from(vec![
+ "a,b,c",
+ "ApacheSoftware,Github,DataFusion",
+ ""
+ ]))),
+ ],
+ Int32Array::from(vec![0, 1, 0])
+ );
+ test_find_in_set!(
+ test_find_in_set_with_scalar_args_3,
+ vec![
+ ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>;
3]))),
+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a,b,c".to_string()))),
+ ],
+ Int32Array::from(vec![None::<i32>; 3])
+ );
+ test_find_in_set!(
+ test_find_in_set_with_scalar_args_4,
+ vec![
+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a".to_string()))),
+ ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>;
3]))),
+ ],
+ Int32Array::from(vec![None::<i32>; 3])
+ );
}
diff --git a/datafusion/functions/src/unicode/mod.rs
b/datafusion/functions/src/unicode/mod.rs
index e8e3eb3f4e..3c5cde3789 100644
--- a/datafusion/functions/src/unicode/mod.rs
+++ b/datafusion/functions/src/unicode/mod.rs
@@ -102,7 +102,7 @@ pub mod expr_fn {
string
),(
find_in_set,
- "Returns a value in the range of 1 to N if the string str is in the
string list strlist consisting of N substrings",
+ "Returns a value in the range of 1 to N if the string `str` is in the
string list `strlist` consisting of N substrings",
string strlist
));
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]