tlm365 commented on code in PR #14020: URL: https://github.com/apache/datafusion/pull/14020#discussion_r1911936164
########## datafusion/functions/src/unicode/find_in_set.rs: ########## @@ -138,31 +263,279 @@ fn find_in_set(args: &[ArrayRef]) -> Result<ArrayRef> { } } -pub fn find_in_set_general<'a, T: ArrowPrimitiveType, V: ArrayAccessor<Item = &'a str>>( +pub fn find_in_set_general<'a, T, V>( string_array: V, str_list_array: V, ) -> Result<ArrayRef> where + T: ArrowPrimitiveType, T::Native: OffsetSizeTrait, + V: ArrayAccessor<Item = &'a str>, { let string_iter = ArrayIter::new(string_array); let str_list_iter = ArrayIter::new(str_list_array); - let result = string_iter + + let mut builder = PrimitiveArray::<T>::builder(string_iter.len()); + + string_iter .zip(str_list_iter) - .map(|(string, str_list)| match (string, str_list) { - (Some(string), Some(str_list)) => { - let mut res = 0; - let str_set: Vec<&str> = str_list.split(',').collect(); - for (idx, str) in str_set.iter().enumerate() { - if str == &string { - res = idx + 1; - break; - } + .for_each( + |(string_opt, str_list_opt)| match (string_opt, str_list_opt) { + (Some(string), Some(str_list)) => { + let position = str_list + .split(',') + .position(|s| s == string) + .map_or(0, |idx| idx + 1); + builder.append_value(T::Native::from_usize(position).unwrap()); } - T::Native::from_usize(res) + _ => builder.append_null(), + }, + ); + + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +fn find_in_set_left_literal<'a, T, V>( + string: String, + str_list_array: V, +) -> Result<ArrayRef> +where + T: ArrowPrimitiveType, + T::Native: OffsetSizeTrait, + V: ArrayAccessor<Item = &'a str>, +{ + let mut builder = PrimitiveArray::<T>::builder(str_list_array.len()); + + let str_list_iter = ArrayIter::new(str_list_array); + + str_list_iter.for_each(|str_list_opt| match str_list_opt { + Some(str_list) => { + let position = str_list + .split(',') + .position(|s| s == string) + .map_or(0, |idx| idx + 1); + builder.append_value(T::Native::from_usize(position).unwrap()); + } + None => builder.append_null(), + }); + + 
Ok(Arc::new(builder.finish()) as ArrayRef) +} + +fn find_in_set_right_literal<'a, T, V>( + string_array: V, + str_list: Vec<&str>, +) -> Result<ArrayRef> +where + T: ArrowPrimitiveType, + T::Native: OffsetSizeTrait, + V: ArrayAccessor<Item = &'a str>, +{ + let mut builder = PrimitiveArray::<T>::builder(string_array.len()); + + let string_iter = ArrayIter::new(string_array); + + string_iter.for_each(|string_opt| match string_opt { + Some(string) => { + let position = str_list + .iter() + .position(|s| *s == string) + .map_or(0, |idx| idx + 1); + builder.append_value(T::Native::from_usize(position).unwrap()); + } + None => builder.append_null(), + }); + + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use crate::unicode::find_in_set::FindInSetFunc; + use crate::utils::test::test_function; + use arrow::array::{Array, Int32Array, StringArray}; + use arrow::datatypes::DataType::Int32; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use std::sync::Arc; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))), + ], + Ok(Some(1)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("🔥")))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "a,Д,🔥" + )))), + ], + Ok(Some(3)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("d")))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))), + ], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + 
"Apache Software Foundation" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "Github,Apache Software Foundation,DataFusion" + )))), + ], + Ok(Some(2)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))), + ], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))), + ], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a")))), + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ], + Ok(None), + i32, + Int32, + Int32Array + ); + test_function!( + FindInSetFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))), + ], + Ok(None), + i32, + Int32, + Int32Array + ); + + Ok(()) + } + + macro_rules! test_find_in_set { + ($test_name:ident, $args:expr, $expected:expr) => { + #[test] + fn $test_name() -> Result<()> { + let fis = crate::unicode::find_in_set(); + + let args = $args; + let expected = $expected; + + let type_array = args.iter().map(|a| a.data_type()).collect::<Vec<_>>(); + let cardinality = args + .iter() + .fold(Option::<usize>::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }) + .unwrap_or(1); + let return_type = fis.return_type(&type_array)?; + let result = fis.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: cardinality, + return_type: &return_type, + }); + assert!(result.is_ok()); + + let result = result? 
+ .to_array(cardinality) + .expect("Failed to convert to array"); + let result = result + .as_any() + .downcast_ref::<Int32Array>() + .expect("Failed to convert to type"); + assert_eq!(*result, expected); + + Ok(()) } - _ => None, - }) - .collect::<PrimitiveArray<T>>(); - Ok(Arc::new(result) as ArrayRef) + }; + } + + test_find_in_set!( + test_find_in_set_with_scalar_args, + vec![ + ColumnarValue::Array(Arc::new(StringArray::from(vec![ + "", "a", "b", "c", "d" + ]))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("b,c,d".to_string()))), + ], + Int32Array::from(vec![0, 0, 1, 2, 3]) + ); + test_find_in_set!( Review Comment: > Is it possible to provide a test case so we can file a ticket if this is undesired behavior? @comphead These are the test cases I mentioned before (`test_find_in_set_with_scalar_args_2`, `test_find_in_set_with_scalar_args_3`, and `test_find_in_set_with_scalar_args_4`). -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org