jorgecarleitao commented on a change in pull request #8975:
URL: https://github.com/apache/arrow/pull/8975#discussion_r550492883
##########
File path: rust/arrow/src/compute/kernels/comparison.rs
##########
@@ -433,47 +461,74 @@ where
/// Helper function to perform boolean lambda function on values from an array and a scalar value using
/// SIMD.
#[cfg(simd_x86)]
-fn simd_compare_op_scalar<T, F>(
+fn simd_compare_op_scalar<T, SIMD_OP, SCALAR_OP>(
left: &PrimitiveArray<T>,
right: T::Native,
- op: F,
+ simd_op: SIMD_OP,
+ scalar_op: SCALAR_OP,
) -> Result<BooleanArray>
where
T: ArrowNumericType,
- F: Fn(T::Simd, T::Simd) -> T::SimdMask,
+ SIMD_OP: Fn(T::Simd, T::Simd) -> T::SimdMask,
+ SCALAR_OP: Fn(T::Native, T::Native) -> bool,
{
- use std::mem;
+ use std::borrow::BorrowMut;
let len = left.len();
- let null_bit_buffer = left.data().null_buffer().cloned();
+
+ let null_bit_buffer = left
+ .data_ref()
+ .null_buffer()
+ .map(|b| b.bit_slice(left.offset(), left.len()));
+
let lanes = T::lanes();
- let mut result = MutableBuffer::new(left.len() * mem::size_of::<bool>());
+ let buffer_size = bit_util::ceil(len, 8);
+ let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size,
false);
+
+ // this is currently the case for all our datatypes and allows us to always append full bytes
+ assert!(
+ lanes % 8 == 0,
+ "Number of vector lanes must be multiple of 8"
+ );
+ let mut left_chunks = left.values().chunks_exact(lanes);
let simd_right = T::init(right);
- let rem = len % lanes;
+ let result_remainder = left_chunks.borrow_mut().fold(
+ result.typed_data_mut(),
+ |result_slice, left_slice| {
+ let simd_left = T::load(left_slice);
+ let simd_result = simd_op(simd_left, simd_right);
- for i in (0..len - rem).step_by(lanes) {
- let simd_left = T::load(unsafe { left.value_slice(i, lanes) });
- let simd_result = op(simd_left, simd_right);
- T::bitmask(&simd_result, |b| {
- result.extend_from_slice(b);
- });
- }
+ let bitmask = T::mask_to_u64(&simd_result);
+ let bytes = bitmask.to_le_bytes();
+ &result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]);
- if rem > 0 {
- //Soundness
- // This is not sound because it can read past the end of PrimitiveArray buffer (lanes is always greater than rem), see ARROW-10990
- let simd_left = T::load(unsafe { left.value_slice(len - rem, lanes) });
- let simd_result = op(simd_left, simd_right);
- let rem_buffer_size = (rem as f32 / 8f32).ceil() as usize;
- T::bitmask(&simd_result, |b| {
- result.extend_from_slice(&b[0..rem_buffer_size]);
- });
- }
+ &mut result_slice[lanes / 8..]
+ },
+ );
+
+ let left_remainder = left_chunks.remainder();
+
+ let remainder_bitmask =
+ left_remainder
+ .iter()
+ .enumerate()
+ .fold(0_u64, |mut mask, (i, scalar_left)| {
+ let bit = if scalar_op(*scalar_left, right) {
+ 1_u64
+ } else {
+ 0_u64
+ };
+ mask |= bit << i;
+ mask
+ });
+ let remainder_mask_as_bytes =
+
&remainder_bitmask.to_le_bytes()[0..bit_util::ceil(left_remainder.len(), 8)];
+ result_remainder.copy_from_slice(remainder_mask_as_bytes);
let data = ArrayData::new(
DataType::Boolean,
- left.len(),
+ len,
None,
null_bit_buffer,
0,
Review comment:
This API was deprecated to remove the risk of a wrong count,
which I think addresses this altogether :)
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]