alamb commented on a change in pull request #1127: URL: https://github.com/apache/arrow-rs/pull/1127#discussion_r777220437
########## File path: arrow/src/array/array.rs ########## @@ -227,6 +227,26 @@ pub trait Array: fmt::Debug + Send + Sync + JsonEqual { /// A reference-counted reference to a generic `Array`. pub type ArrayRef = Arc<dyn Array>; +impl Array for ArrayRef { Review comment: this needs to be in its own PR I think ########## File path: arrow/src/compute/kernels/comparison.rs ########## @@ -1041,280 +1037,155 @@ macro_rules! dyn_compare_utf8_scalar { ($LEFT: expr, $RIGHT: expr, $KT: ident, $OP: ident) => {{ match $KT.as_ref() { DataType::UInt8 => { - let left = as_dictionary_array::<UInt8Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<UInt8Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::UInt16 => { - let left = as_dictionary_array::<UInt16Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<UInt16Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::UInt32 => { - let left = as_dictionary_array::<UInt32Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<UInt32Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::UInt64 => { - let left = as_dictionary_array::<UInt64Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<UInt64Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) 
} DataType::Int8 => { - let left = as_dictionary_array::<Int8Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<Int8Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::Int16 => { - let left = as_dictionary_array::<Int16Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<Int16Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::Int32 => { - let left = as_dictionary_array::<Int32Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<Int32Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::Int64 => { - let left = as_dictionary_array::<Int64Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<Int64Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } - _ => Err(ArrowError::ComputeError(String::from("Unknown key type"))), + _ => Err(ArrowError::ComputeError(format!( + "Unsupported dictionary key type {:?}", + $KT.as_ref() + ))), } }}; } /// Perform `left == right` operation on an array and a numeric scalar /// value. 
Supports PrimitiveArrays, and DictionaryArrays that have primitive values -pub fn eq_dyn_scalar<T>(left: Arc<dyn Array>, right: T) -> Result<BooleanArray> +pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray> where - T: TryInto<i128> + Copy + std::fmt::Debug, + T: num::ToPrimitive + Copy + std::fmt::Debug, Review comment: Here is API change 2: Use `num::ToPrimitive` rather than `TryInto<i128>` ########## File path: arrow/src/array/cast.rs ########## @@ -69,6 +69,29 @@ pub fn as_generic_binary_array<S: BinaryOffsetSizeTrait>( .expect("Unable to downcast to binary array") } +/// Downcast `$arr` to the array of type `$arrty`. Panic's if the Review comment: not sure this is needed yet ########## File path: arrow/src/compute/kernels/comparison.rs ########## @@ -1041,280 +1037,155 @@ macro_rules! dyn_compare_utf8_scalar { ($LEFT: expr, $RIGHT: expr, $KT: ident, $OP: ident) => {{ match $KT.as_ref() { DataType::UInt8 => { - let left = as_dictionary_array::<UInt8Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<UInt8Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::UInt16 => { - let left = as_dictionary_array::<UInt16Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<UInt16Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::UInt32 => { - let left = as_dictionary_array::<UInt32Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<UInt32Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) 
} DataType::UInt64 => { - let left = as_dictionary_array::<UInt64Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<UInt64Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::Int8 => { - let left = as_dictionary_array::<Int8Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<Int8Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::Int16 => { - let left = as_dictionary_array::<Int16Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<Int16Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::Int32 => { - let left = as_dictionary_array::<Int32Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<Int32Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::Int64 => { - let left = as_dictionary_array::<Int64Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<Int64Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } - _ => Err(ArrowError::ComputeError(String::from("Unknown key type"))), + _ => Err(ArrowError::ComputeError(format!( + "Unsupported dictionary key type {:?}", + $KT.as_ref() + ))), } }}; } /// Perform `left == right` operation on an array and a numeric scalar /// value. 
Supports PrimitiveArrays, and DictionaryArrays that have primitive values -pub fn eq_dyn_scalar<T>(left: Arc<dyn Array>, right: T) -> Result<BooleanArray> +pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray> where - T: TryInto<i128> + Copy + std::fmt::Debug, + T: num::ToPrimitive + Copy + std::fmt::Debug, { match left.data_type() { - DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { Review comment: these type checks are redundant with the check done in `dyn_compare_scalar!` already (and they don't include `Float32` and `Float64`), so I removed them. ########## File path: arrow/src/compute/kernels/comparison.rs ########## @@ -1041,280 +1037,155 @@ macro_rules! dyn_compare_utf8_scalar { ($LEFT: expr, $RIGHT: expr, $KT: ident, $OP: ident) => {{ match $KT.as_ref() { DataType::UInt8 => { - let left = as_dictionary_array::<UInt8Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<UInt8Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::UInt16 => { - let left = as_dictionary_array::<UInt16Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<UInt16Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::UInt32 => { - let left = as_dictionary_array::<UInt32Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<UInt32Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) 
} DataType::UInt64 => { - let left = as_dictionary_array::<UInt64Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<UInt64Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::Int8 => { - let left = as_dictionary_array::<Int8Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<Int8Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::Int16 => { - let left = as_dictionary_array::<Int16Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<Int16Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::Int32 => { - let left = as_dictionary_array::<Int32Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<Int32Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } DataType::Int64 => { - let left = as_dictionary_array::<Int64Type>($LEFT); - let values = as_string_array(left.values()); + let left = array_downcast!($LEFT, DictionaryArray<Int64Type>); + let values = array_downcast!(left.values(), StringArray); unpack_dict_comparison(left, $OP(values, $RIGHT)?) } - _ => Err(ArrowError::ComputeError(String::from("Unknown key type"))), + _ => Err(ArrowError::ComputeError(format!( + "Unsupported dictionary key type {:?}", + $KT.as_ref() + ))), } }}; } /// Perform `left == right` operation on an array and a numeric scalar /// value. 
Supports PrimitiveArrays, and DictionaryArrays that have primitive values -pub fn eq_dyn_scalar<T>(left: Arc<dyn Array>, right: T) -> Result<BooleanArray> +pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray> Review comment: Here is API change 1: take `&dyn Array` rather than `Arc<dyn Array>` so these kernels can be used on arrays that are not in arcs and the caller doesn't have to increment the reference count ########## File path: arrow/src/compute/kernels/comparison.rs ########## @@ -3214,18 +3100,34 @@ mod tests { builder.append(123).unwrap(); builder.append_null().unwrap(); builder.append(23).unwrap(); - let array = Arc::new(builder.finish()); - let a_eq = eq_dyn_scalar(array, 123).unwrap(); + let array = builder.finish(); + let a_eq = eq_dyn_scalar(&array, 123).unwrap(); assert_eq!( a_eq, BooleanArray::from(vec![Some(true), None, Some(false)]) ); } + + #[test] + fn test_eq_dyn_scalar_float() { + let array: Float32Array = vec![6.0, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(true), Some(false)], + ); + assert_eq!(eq_dyn_scalar(&array, 8).unwrap(), expected); + + let array: ArrayRef = Arc::new(array); + let array = crate::compute::cast(&array, &DataType::Float64).unwrap(); + assert_eq!(eq_dyn_scalar(&array, 8).unwrap(), expected); + } + #[test] fn test_lt_dyn_scalar() { let array = Int32Array::from(vec![6, 7, 8, 8, 10]); - let array = Arc::new(array); - let a_eq = lt_dyn_scalar(array, 8).unwrap(); + let a_eq = lt_dyn_scalar(&array, 8).unwrap(); Review comment: The changes to the test illustrate why the API change to not pass in an `Arc` is nicer, I think ########## File path: arrow/src/compute/kernels/comparison.rs ########## @@ -890,149 +890,145 @@ pub fn gt_eq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>( compare_op_scalar!(left, right, |a, b| a >= b) } +/// Calls $RIGHT.$TY() (e.g. `right.to_i128()`) with a nice error message. 
+/// Type of expression is `Result<.., ArrowError>` +macro_rules! try_to_type { + ($RIGHT: expr, $TY: ident) => {{ + $RIGHT.$TY().ok_or_else(|| { + ArrowError::ComputeError(format!( + "Could not convert {} with {}", + stringify!($RIGHT), + stringify!($TY) + )) + }) + }}; +} + macro_rules! dyn_compare_scalar { // Applies `LEFT OP RIGHT` when `LEFT` is a `DictionaryArray` ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{ - let right: i128 = $RIGHT.try_into().map_err(|_| { - ArrowError::ComputeError(String::from("Can not convert scalar to i128")) - })?; match $LEFT.data_type() { DataType::Int8 => { - let right: i8 = right.try_into().map_err(|_| { - ArrowError::ComputeError(String::from("Can not convert scalar to i8")) - })?; - let left = as_primitive_array::<Int8Type>($LEFT); + let right = try_to_type!($RIGHT, to_i8)?; + let left = array_downcast!($LEFT, Int8Array); $OP::<Int8Type>(left, right) } DataType::Int16 => { - let right: i16 = right.try_into().map_err(|_| { - ArrowError::ComputeError(String::from( - "Can not convert scalar to i16", - )) - })?; - let left = as_primitive_array::<Int16Type>($LEFT); + let right = try_to_type!($RIGHT, to_i16)?; + let left = array_downcast!($LEFT, Int16Array); $OP::<Int16Type>(left, right) } DataType::Int32 => { - let right: i32 = right.try_into().map_err(|_| { - ArrowError::ComputeError(String::from( - "Can not convert scalar to i32", - )) - })?; - let left = as_primitive_array::<Int32Type>($LEFT); + let right = try_to_type!($RIGHT, to_i32)?; + let left = array_downcast!($LEFT, Int32Array); $OP::<Int32Type>(left, right) } DataType::Int64 => { - let right: i64 = right.try_into().map_err(|_| { - ArrowError::ComputeError(String::from( - "Can not convert scalar to i64", - )) - })?; - let left = as_primitive_array::<Int64Type>($LEFT); + let right = try_to_type!($RIGHT, to_i64)?; + let left = array_downcast!($LEFT, Int64Array); $OP::<Int64Type>(left, right) } DataType::UInt8 => { - let right: u8 = right.try_into().map_err(|_| { - 
ArrowError::ComputeError(String::from("Can not convert scalar to u8")) - })?; - let left = as_primitive_array::<UInt8Type>($LEFT); + let right = try_to_type!($RIGHT, to_u8)?; + let left = array_downcast!($LEFT, UInt8Array); $OP::<UInt8Type>(left, right) } DataType::UInt16 => { - let right: u16 = right.try_into().map_err(|_| { - ArrowError::ComputeError(String::from( - "Can not convert scalar to u16", - )) - })?; - let left = as_primitive_array::<UInt16Type>($LEFT); + let right = try_to_type!($RIGHT, to_u16)?; + let left = array_downcast!($LEFT, UInt16Array); $OP::<UInt16Type>(left, right) } DataType::UInt32 => { - let right: u32 = right.try_into().map_err(|_| { - ArrowError::ComputeError(String::from( - "Can not convert scalar to u32", - )) - })?; - let left = as_primitive_array::<UInt32Type>($LEFT); + let right = try_to_type!($RIGHT, to_u32)?; + let left = array_downcast!($LEFT, UInt32Array); $OP::<UInt32Type>(left, right) } DataType::UInt64 => { - let right: u64 = right.try_into().map_err(|_| { - ArrowError::ComputeError(String::from( - "Can not convert scalar to u64", - )) - })?; - let left = as_primitive_array::<UInt64Type>($LEFT); + let right = try_to_type!($RIGHT, to_u64)?; + let left = array_downcast!($LEFT, UInt64Array); $OP::<UInt64Type>(left, right) } - _ => Err(ArrowError::ComputeError(String::from( - "Unsupported data type", + DataType::Float32 => { Review comment: here is support for `Float32` and `Float64` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org