This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 72d7c5b fix: Scalar math operations on slices (#743)
72d7c5b is described below
commit 72d7c5b045ed2bf2572e82b973204112db9545c7
Author: Ben Chambers <[email protected]>
AuthorDate: Thu Sep 9 13:25:46 2021 -0700
fix: Scalar math operations on slices (#743)
* fix: Scalar math operations on slices
* remove conditional
---
arrow/src/compute/kernels/arithmetic.rs | 69 ++++++++++++++++++---------------
1 file changed, 38 insertions(+), 31 deletions(-)
diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs
index 931c266..b9596ee 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -29,7 +29,6 @@ use num::{One, Zero};
use crate::buffer::Buffer;
#[cfg(feature = "simd")]
use crate::buffer::MutableBuffer;
-#[cfg(not(feature = "simd"))]
use crate::compute::kernels::arity::unary;
use crate::compute::util::combine_option_bitmap;
use crate::datatypes;
@@ -83,7 +82,10 @@ where
T::DATA_TYPE,
array.len(),
None,
- array.data_ref().null_buffer().cloned(),
+ array
+ .data_ref()
+ .null_buffer()
+ .map(|b| b.bit_slice(array.offset(), array.len())),
0,
vec![result.into()],
vec![],
@@ -132,7 +134,10 @@ where
T::DATA_TYPE,
array.len(),
None,
- array.data_ref().null_buffer().cloned(),
+ array
+ .data_ref()
+ .null_buffer()
+ .map(|b| b.bit_slice(array.offset(), array.len())),
0,
vec![result.into()],
vec![],
@@ -338,19 +343,7 @@ where
return Err(ArrowError::DivideByZero);
}
- let values = array.values().iter().map(|value| *value % modulo);
- let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
-
- let data = ArrayData::new(
- T::DATA_TYPE,
- array.len(),
- None,
- array.data_ref().null_buffer().cloned(),
- 0,
- vec![buffer],
- vec![],
- );
- Ok(PrimitiveArray::<T>::from(data))
+ Ok(unary(array, |value| value % modulo))
}
/// Scalar-divisor version of `math_divide`.
@@ -366,19 +359,7 @@ where
return Err(ArrowError::DivideByZero);
}
- let values = array.values().iter().map(|value| *value / divisor);
- let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
-
- let data = ArrayData::new(
- T::DATA_TYPE,
- array.len(),
- None,
- array.data_ref().null_buffer().cloned(),
- 0,
- vec![buffer],
- vec![],
- );
- Ok(PrimitiveArray::<T>::from(data))
+ Ok(unary(array, |value| value / divisor))
}
/// SIMD vectorized version of `math_op` above.
@@ -914,7 +895,10 @@ where
T::DATA_TYPE,
array.len(),
None,
- array.data_ref().null_buffer().cloned(),
+ array
+ .data_ref()
+ .null_buffer()
+ .map(|b| b.bit_slice(array.offset(), array.len())),
0,
vec![result.into()],
vec![],
@@ -960,7 +944,10 @@ where
T::DATA_TYPE,
array.len(),
None,
- array.data_ref().null_buffer().cloned(),
+ array
+ .data_ref()
+ .null_buffer()
+ .map(|b| b.bit_slice(array.offset(), array.len())),
0,
vec![result.into()],
vec![],
@@ -1264,6 +1251,16 @@ mod tests {
}
#[test]
+ fn test_primitive_array_divide_scalar_sliced() {
+ let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]);
+ let a = a.slice(1, 4);
+ let a = as_primitive_array(&a);
+ let actual = divide_scalar(a, 3).unwrap();
+ let expected = Int32Array::from(vec![None, Some(3), Some(2), None]);
+ assert_eq!(actual, expected);
+ }
+
+ #[test]
fn test_primitive_array_modulus_scalar() {
let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
let b = 3;
@@ -1273,6 +1270,16 @@ mod tests {
}
#[test]
+ fn test_primitive_array_modulus_scalar_sliced() {
+ let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]);
+ let a = a.slice(1, 4);
+ let a = as_primitive_array(&a);
+ let actual = modulus_scalar(a, 3).unwrap();
+ let expected = Int32Array::from(vec![None, Some(0), Some(2), None]);
+ assert_eq!(actual, expected);
+ }
+
+ #[test]
fn test_primitive_array_divide_sliced() {
let a = Int32Array::from(vec![0, 0, 0, 15, 15, 8, 1, 9, 0]);
let b = Int32Array::from(vec![0, 0, 0, 5, 6, 8, 9, 1, 0]);