This is an automated email from the ASF dual-hosted git repository.
remziy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 191eaef09 Add modulus ops into `ArrowNativeTypeOp` (#2756)
191eaef09 is described below
commit 191eaef0906f61dc0dc6ca6cea96a99f86e6c5a4
Author: Remzi Yang <[email protected]>
AuthorDate: Tue Oct 4 07:26:24 2022 +0800
Add modulus ops into `ArrowNativeTypeOp` (#2756)
* add 3 mod ops and tests
Signed-off-by: remzi <[email protected]>
* fix simd error
Signed-off-by: remzi <[email protected]>
* remove_mod_divide_by_zero
Signed-off-by: remzi <[email protected]>
* overflow panic simd
Signed-off-by: remzi <[email protected]>
* address comment
Signed-off-by: remzi <[email protected]>
Signed-off-by: remzi <[email protected]>
---
arrow/src/compute/kernels/arithmetic.rs | 64 +++++++++++++++++++++++++++------
arrow/src/datatypes/native.rs | 32 ++++++++++++++++-
2 files changed, 85 insertions(+), 11 deletions(-)
diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs
index b2e95ad5e..1e6e55248 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -22,7 +22,7 @@
//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
-use std::ops::{Div, Neg, Rem};
+use std::ops::{Div, Neg};
use num::{One, Zero};
@@ -182,7 +182,7 @@ fn simd_checked_modulus<T: ArrowNumericType>(
right: T::Simd,
) -> Result<T::Simd>
where
- T::Native: One + Zero,
+ T::Native: ArrowNativeTypeOp + One,
{
let zero = T::init(T::Native::zero());
let one = T::init(T::Native::one());
@@ -305,7 +305,7 @@ fn simd_checked_divide_op<T, SI, SC>(
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
- T::Native: One + Zero,
+ T::Native: ArrowNativeTypeOp,
SI: Fn(Option<u64>, T::Simd, T::Simd) -> Result<T::Simd>,
SC: Fn(T::Native, T::Native) -> T::Native,
{
@@ -1301,7 +1301,7 @@ pub fn modulus<T>(
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
- T::Native: Rem<Output = T::Native> + Zero + One,
+ T::Native: ArrowNativeTypeOp + One,
{
#[cfg(feature = "simd")]
return simd_checked_divide_op(&left, &right, simd_checked_modulus::<T>,
|a, b| {
@@ -1312,7 +1312,7 @@ where
if b.is_zero() {
Err(ArrowError::DivideByZero)
} else {
- Ok(a % b)
+ Ok(a.mod_wrapping(b))
}
});
}
@@ -1507,13 +1507,13 @@ pub fn modulus_scalar<T>(
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
- T::Native: Rem<Output = T::Native> + Zero,
+ T::Native: ArrowNativeTypeOp,
{
if modulo.is_zero() {
return Err(ArrowError::DivideByZero);
}
- Ok(unary(array, |a| a % modulo))
+ Ok(unary(array, |a| a.mod_wrapping(modulo)))
}
/// Divide every value in an array by a scalar. If any value in the array is
null then the
@@ -2117,7 +2117,7 @@ mod tests {
}
#[test]
- fn test_primitive_array_modulus() {
+ fn test_int_array_modulus() {
let a = Int32Array::from(vec![15, 15, 8, 1, 9]);
let b = Int32Array::from(vec![5, 6, 8, 9, 1]);
let c = modulus(&a, &b).unwrap();
@@ -2128,6 +2128,34 @@ mod tests {
assert_eq!(0, c.value(4));
}
+ #[test]
+ #[should_panic(
+ expected = "called `Result::unwrap()` on an `Err` value: DivideByZero"
+ )]
+ fn test_int_array_modulus_divide_by_zero() {
+ let a = Int32Array::from(vec![1]);
+ let b = Int32Array::from(vec![0]);
+ modulus(&a, &b).unwrap();
+ }
+
+ #[test]
+ #[cfg(not(feature = "simd"))]
+ fn test_int_array_modulus_overflow_wrapping() {
+ let a = Int32Array::from(vec![i32::MIN]);
+ let b = Int32Array::from(vec![-1]);
+ let result = modulus(&a, &b).unwrap();
+ assert_eq!(0, result.value(0))
+ }
+
+ #[test]
+ #[cfg(feature = "simd")]
+ #[should_panic(expected = "attempt to calculate the remainder with
overflow")]
+ fn test_int_array_modulus_overflow_panic() {
+ let a = Int32Array::from(vec![i32::MIN]);
+ let b = Int32Array::from(vec![-1]);
+ let _ = modulus(&a, &b).unwrap();
+ }
+
#[test]
fn test_primitive_array_divide_scalar() {
let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
@@ -2190,7 +2218,7 @@ mod tests {
}
#[test]
- fn test_primitive_array_modulus_scalar() {
+ fn test_int_array_modulus_scalar() {
let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
let b = 3;
let c = modulus_scalar(&a, b).unwrap();
@@ -2199,7 +2227,7 @@ mod tests {
}
#[test]
- fn test_primitive_array_modulus_scalar_sliced() {
+ fn test_int_array_modulus_scalar_sliced() {
let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]);
let a = a.slice(1, 4);
let a = as_primitive_array(&a);
@@ -2208,6 +2236,22 @@ mod tests {
assert_eq!(actual, expected);
}
+ #[test]
+ #[should_panic(
+ expected = "called `Result::unwrap()` on an `Err` value: DivideByZero"
+ )]
+ fn test_int_array_modulus_scalar_divide_by_zero() {
+ let a = Int32Array::from(vec![1]);
+ modulus_scalar(&a, 0).unwrap();
+ }
+
+ #[test]
+ fn test_int_array_modulus_scalar_overflow_wrapping() {
+ let a = Int32Array::from(vec![i32::MIN]);
+ let result = modulus_scalar(&a, -1).unwrap();
+ assert_eq!(0, result.value(0))
+ }
+
#[test]
fn test_primitive_array_divide_sliced() {
let a = Int32Array::from(vec![0, 0, 0, 15, 15, 8, 1, 9, 0]);
diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs
index 6ab82688e..654b93950 100644
--- a/arrow/src/datatypes/native.rs
+++ b/arrow/src/datatypes/native.rs
@@ -26,7 +26,7 @@ pub(crate) mod native_op {
use super::ArrowNativeType;
use crate::error::{ArrowError, Result};
use num::Zero;
- use std::ops::{Add, Div, Mul, Sub};
+ use std::ops::{Add, Div, Mul, Rem, Sub};
/// Trait for ArrowNativeType to provide overflow-checking and
non-overflow-checking
/// variants for arithmetic operations. For floating point types, this
provides some
@@ -44,6 +44,7 @@ pub(crate) mod native_op {
+ Sub<Output = Self>
+ Mul<Output = Self>
+ Div<Output = Self>
+ + Rem<Output = Self>
+ Zero
{
fn add_checked(self, rhs: Self) -> Result<Self> {
@@ -81,6 +82,18 @@ pub(crate) mod native_op {
fn div_wrapping(self, rhs: Self) -> Self {
self / rhs
}
+
+ fn mod_checked(self, rhs: Self) -> Result<Self> {
+ if rhs.is_zero() {
+ Err(ArrowError::DivideByZero)
+ } else {
+ Ok(self % rhs)
+ }
+ }
+
+ fn mod_wrapping(self, rhs: Self) -> Self {
+ self % rhs
+ }
}
}
@@ -142,6 +155,23 @@ macro_rules! native_type_op {
fn div_wrapping(self, rhs: Self) -> Self {
self.wrapping_div(rhs)
}
+
+ fn mod_checked(self, rhs: Self) -> Result<Self> {
+ if rhs.is_zero() {
+ Err(ArrowError::DivideByZero)
+ } else {
+ self.checked_rem(rhs).ok_or_else(|| {
+ ArrowError::ComputeError(format!(
+ "Overflow happened on: {:?} % {:?}",
+ self, rhs
+ ))
+ })
+ }
+ }
+
+ fn mod_wrapping(self, rhs: Self) -> Self {
+ self.wrapping_rem(rhs)
+ }
}
};
}