alamb commented on a change in pull request #984:
URL: https://github.com/apache/arrow-rs/pull/984#discussion_r765173303
##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -1200,6 +1257,29 @@ where
return compare_op_scalar!(left, right, |a, b| a == b);
}
+/// Perform `left == right` operation on a [`PrimitiveArray`] and a scalar value.
+pub fn eq_dict_scalar<T>(
+ left: &DictionaryArray<T>,
+ right: T::Native,
+) -> Result<BooleanArray>
+where
+ T: ArrowNumericType,
+{
+ #[cfg(not(feature = "simd"))]
+ println!("{}", std::any::type_name::<T>());
+ return compare_dict_op_scalar!(left, T, right, |a, b| a == b);
+}
Review comment:
@matthewmturner this is what I was trying to say.
I think the way you have this function, with a single `T` generic parameter, means one could not compare a `DictionaryArray<Int8>` (i.e. one whose keys / indexes are `Int8`) that has values of type `DataType::UInt16`.
Here is a sketch of how this might work:
```rust
/// Perform `left == right` operation on a [`DictionaryArray`] and a numeric
/// scalar value.
pub fn eq_dict_scalar<T, K>(
    left: &DictionaryArray<K>,
    right: T::Native,
) -> Result<BooleanArray>
where
    T: ArrowNumericType,
    K: ArrowNumericType,
{
    // Compare to the dictionary values. E.g. if the dictionary is {A, B}
    // and the keys are {1, 0, 1, 1}, that represents the values B, A, B, B.
    //
    // So we compare just the dictionary values {A, B} to `right` and then
    // look up the precomputed result for each key.
    //
    // TODO macro-ize this
    let dictionary_comparison = match left.values().data_type() {
        DataType::Int8 => eq_scalar(as_primitive_array::<T>(left.values()), right),
        // TODO fill in Int16, Int32, etc
        _ => unimplemented!("Should error: dictionary did not store values of type T"),
    }?;

    // Required for safety below
    assert_eq!(dictionary_comparison.len(), left.values().len());

    // Now, look up the dictionary for each output
    let result: BooleanArray = left
        .keys()
        .iter()
        .map(|key| {
            // figure out how the dictionary element at this index
            // compared to the scalar
            key.map(|key| {
                // safety: the original array's indices were valid
                // (`0..left.values().len()`) and `dictionary_comparison`
                // is the same size, checked above
                unsafe {
                    // it would be nice to avoid checking the conversion each time
                    let key = key.to_usize().expect("Dictionary index not usize");
                    dictionary_comparison.value_unchecked(key)
                }
            })
        })
        .collect();

    Ok(result)
}
```
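To make the idea concrete without depending on arrow's types, here is a minimal, standard-library-only sketch of the same two-step technique. `DictColumn` and this `eq_dict_scalar` are hypothetical stand-ins, not arrow-rs APIs. The key type `K` and the value type `V` are independent generic parameters (so `u8` keys can index `u16` values). The comparison runs once per dictionary value, each key then looks up its precomputed result, and null keys stay null.

```rust
use std::convert::TryInto;

// Hypothetical stand-in for a dictionary-encoded column (not an arrow-rs type):
// each key indexes into `values`, and a `None` key marks a null slot.
// K (key type) and V (value type) are independent generic parameters.
pub struct DictColumn<K, V> {
    pub keys: Vec<Option<K>>,
    pub values: Vec<V>,
}

/// Hypothetical sketch: compare every slot of `col` with `right`.
/// `==` is evaluated once per dictionary value, then gathered per key.
pub fn eq_dict_scalar<K, V>(col: &DictColumn<K, V>, right: V) -> Vec<Option<bool>>
where
    K: Copy + TryInto<usize>,
    V: PartialEq,
{
    // Step 1: compare only the (usually small) dictionary to `right`.
    let dictionary_comparison: Vec<bool> =
        col.values.iter().map(|v| *v == right).collect();

    // Step 2: look up each key's precomputed result; null keys stay null.
    col.keys
        .iter()
        .map(|&key| {
            key.map(|k| {
                let idx: usize = k
                    .try_into()
                    .ok()
                    .expect("dictionary key does not fit in usize");
                // Bounds-checked indexing: an out-of-range key panics
                // instead of reading out of bounds.
                dictionary_comparison[idx]
            })
        })
        .collect()
}
```

For example, `u8` keys `[1, 0, null, 1, 1]` over the `u16` dictionary `{7, 300}`, compared with `300`, give `[true, false, null, true, true]`. Using checked indexing instead of `value_unchecked` costs one bounds check per row but removes the need for an `unsafe` block.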
##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -200,6 +201,54 @@ macro_rules! compare_op_scalar_primitive {
}};
}
+macro_rules! compare_dict_op_scalar {
+ ($left:expr, $T:ident, $right:expr, $op:expr) => {{
+ let null_bit_buffer = $left
+ .data()
+ .null_buffer()
+ .map(|b| b.bit_slice($left.offset(), $left.len()));
+
+ let values = $left.values();
+
+ let array = values
+ .as_any()
+ .downcast_ref::<PrimitiveArray<$T>>()
+ .unwrap();
+
+ // Safety:
+ // `i < $left.len()`
+ let comparison: Vec<bool> = (0..array.len())
Review comment:
I didn't think the `values()` array (the dictionary) has to be the same size as the overall array 🤔
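The number of rows (keys) and the number of dictionary entries are independent, so an unchecked lookup depends on the bound `key < values.len()`, not `i < left.len()`. The following minimal, std-only sketch uses a hypothetical `gather_by_key` helper over plain slices rather than arrow-rs types. It validates every key once against the dictionary size and only then reads without bounds checks.

```rust
/// Hypothetical helper (not an arrow-rs API): expand per-dictionary-value
/// results into per-row results. `dict_results.len()` is the dictionary
/// size; `keys.len()` is the row count, and the two need not be equal.
pub fn gather_by_key(dict_results: &[bool], keys: &[usize]) -> Option<Vec<bool>> {
    // Validate once: every key must index into the dictionary.
    if keys.iter().any(|&k| k >= dict_results.len()) {
        return None;
    }
    let rows = keys
        .iter()
        // SAFETY: every `k` was checked above to be < dict_results.len().
        .map(|&k| unsafe { *dict_results.get_unchecked(k) })
        .collect();
    Some(rows)
}
```

With a two-entry dictionary and five keys, the output has five entries, one per row. A loop over `0..values.len()` alone therefore cannot produce the per-row result.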