This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 6d86472fa Add overflow-checking variant for primitive arithmetic
kernels and explicitly define overflow behavior (#2643)
6d86472fa is described below
commit 6d86472fa3c68986dc1813d3cb027748472ec22f
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sun Sep 4 02:44:36 2022 -0700
Add overflow-checking variant for primitive arithmetic kernels and
explicitly define overflow behavior (#2643)
* Add overflow-checking variant for add kernel and explicitly define
overflow behavior for add
* For subtract, multiply, divide
* Fix tests
* Fix different error message
* Fix typo
* Rename APIs and add more comments. Print values in error message.
* Add one more test to distinguish divide_by_zero behavior on divide.
* Fix clippy
* Update divide doc with dividing by zero behavior for other numeric types.
* Hide ArrowNativeTypeOp
* Fix a typo
---
arrow/benches/arithmetic_kernels.rs | 4 +-
arrow/src/compute/kernels/arithmetic.rs | 262 +++++++++++++++++++++++++++++---
arrow/src/datatypes/native.rs | 106 +++++++++++++
3 files changed, 352 insertions(+), 20 deletions(-)
diff --git a/arrow/benches/arithmetic_kernels.rs
b/arrow/benches/arithmetic_kernels.rs
index 4be4a2693..10af0b543 100644
--- a/arrow/benches/arithmetic_kernels.rs
+++ b/arrow/benches/arithmetic_kernels.rs
@@ -55,13 +55,13 @@ fn bench_multiply(arr_a: &ArrayRef, arr_b: &ArrayRef) {
+// Benchmarks the checking kernel, now named `divide_checked`, on two Float32
+// arrays. The benchmark keeps its old name across the kernel rename.
fn bench_divide(arr_a: &ArrayRef, arr_b: &ArrayRef) {
let arr_a = arr_a.as_any().downcast_ref::<Float32Array>().unwrap();
let arr_b = arr_b.as_any().downcast_ref::<Float32Array>().unwrap();
- criterion::black_box(divide(arr_a, arr_b).unwrap());
+ criterion::black_box(divide_checked(arr_a, arr_b).unwrap());
}
+// Benchmarks the non-checking kernel, which was renamed from
+// `divide_unchecked` to `divide` in this change.
fn bench_divide_unchecked(arr_a: &ArrayRef, arr_b: &ArrayRef) {
let arr_a = arr_a.as_any().downcast_ref::<Float32Array>().unwrap();
let arr_b = arr_b.as_any().downcast_ref::<Float32Array>().unwrap();
- criterion::black_box(divide_unchecked(arr_a, arr_b).unwrap());
+ criterion::black_box(divide(arr_a, arr_b).unwrap());
}
fn bench_divide_scalar(array: &ArrayRef, divisor: f32) {
diff --git a/arrow/src/compute/kernels/arithmetic.rs
b/arrow/src/compute/kernels/arithmetic.rs
index fff687e18..53f48570d 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -35,8 +35,9 @@ use crate::compute::unary_dyn;
use crate::compute::util::combine_option_bitmap;
use crate::datatypes;
use crate::datatypes::{
- ArrowNumericType, DataType, Date32Type, Date64Type, IntervalDayTimeType,
- IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
+ native_op::ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType,
DataType,
+ Date32Type, Date64Type, IntervalDayTimeType, IntervalMonthDayNanoType,
IntervalUnit,
+ IntervalYearMonthType,
};
use crate::datatypes::{
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
UInt16Type,
@@ -103,6 +104,106 @@ where
Ok(PrimitiveArray::<LT>::from(data))
}
+/// This is similar to `math_op` as it performs the given operation between two
+/// input primitive arrays, but `op` may return `None` to signal that overflow was
+/// detected. In that case this function returns an `Err`.
+///
+/// If either input slot is null, the output slot is null and `op` is not called.
+/// Returns an `Err` if the two arrays have different lengths.
+fn math_checked_op<LT, RT, F>(
+    left: &PrimitiveArray<LT>,
+    right: &PrimitiveArray<RT>,
+    op: F,
+) -> Result<PrimitiveArray<LT>>
+where
+    LT: ArrowNumericType,
+    RT: ArrowNumericType,
+    F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
+{
+    if left.len() != right.len() {
+        return Err(ArrowError::ComputeError(
+            "Cannot perform math operation on arrays of different length".to_string(),
+        ));
+    }
+
+    // `ArrayIter` is already an `Iterator`, so it can be zipped directly.
+    let values: Result<Vec<Option<<LT as ArrowPrimitiveType>::Native>>> =
+        ArrayIter::new(left)
+            .zip(ArrayIter::new(right))
+            .map(|(l, r)| match (l, r) {
+                (Some(l), Some(r)) => op(l, r).map(Some).ok_or_else(|| {
+                    // `op` signals overflow by returning `None`.
+                    ArrowError::ComputeError(format!(
+                        "Overflow happened on: {:?}, {:?}",
+                        l, r
+                    ))
+                }),
+                _ => Ok(None),
+            })
+            .collect();
+
+    Ok(PrimitiveArray::<LT>::from_iter(values?))
+}
+
+/// This is similar to `math_checked_op` but specialised for division.
+///
+/// A valid zero on the right-hand side always yields `Err(ArrowError::DivideByZero)`,
+/// for every numeric type. The divisor is checked *before* calling `op`. The default
+/// `ArrowNativeTypeOp::div_checked` used by floating point types never returns `None`,
+/// so without this check a zero divisor would silently produce `inf`/`NaN`. Any other
+/// `None` returned by `op` is reported as an overflow.
+///
+/// If either input slot is null, the output slot is null and `op` is not called.
+/// Returns an `Err` if the two arrays have different lengths.
+fn math_checked_divide<LT, RT, F>(
+    left: &PrimitiveArray<LT>,
+    right: &PrimitiveArray<RT>,
+    op: F,
+) -> Result<PrimitiveArray<LT>>
+where
+    LT: ArrowNumericType,
+    RT: ArrowNumericType,
+    RT::Native: One + Zero,
+    F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
+{
+    if left.len() != right.len() {
+        return Err(ArrowError::ComputeError(
+            "Cannot perform math operation on arrays of different length".to_string(),
+        ));
+    }
+
+    let values: Result<Vec<Option<<LT as ArrowPrimitiveType>::Native>>> =
+        ArrayIter::new(left)
+            .zip(ArrayIter::new(right))
+            .map(|(l, r)| match (l, r) {
+                (Some(l), Some(r)) => {
+                    if r.is_zero() {
+                        return Err(ArrowError::DivideByZero);
+                    }
+                    op(l, r).map(Some).ok_or_else(|| {
+                        // The divisor is non-zero, so `None` means overflow
+                        // (e.g. `i32::MIN / -1`).
+                        ArrowError::ComputeError(format!(
+                            "Overflow happened on: {:?}, {:?}",
+                            l, r
+                        ))
+                    })
+                }
+                _ => Ok(None),
+            })
+            .collect();
+
+    Ok(PrimitiveArray::<LT>::from_iter(values?))
+}
+
/// Helper function for operations where a valid `0` on the right array should
/// result in an [ArrowError::DivideByZero], namely the division and modulo
operations
///
@@ -760,15 +861,34 @@ where
/// Perform `left + right` operation on two arrays. If either left or right value is null
/// then the result is also null.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap around.
+/// For an overflow-checking variant, use `add_checked` instead.
pub fn add<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: ArrowNumericType,
- T::Native: Add<Output = T::Native>,
+ T::Native: ArrowNativeTypeOp,
+{
+ math_op(left, right, |a, b| a.add_wrapping(b))
+}
+
+/// Perform `left + right` operation on two arrays. If either left or right value is null
+/// then the result is also null.
+///
+/// This detects overflow and returns an `Err` for that. For a non-overflow-checking
+/// variant, use `add` instead.
+pub fn add_checked<T>(
+ left: &PrimitiveArray<T>,
+ right: &PrimitiveArray<T>,
+) -> Result<PrimitiveArray<T>>
+where
+ T: ArrowNumericType,
+ T::Native: ArrowNativeTypeOp,
{
- math_op(left, right, |a, b| a + b)
+ math_checked_op(left, right, |a, b| a.add_checked(b))
}
/// Perform `left + right` operation on two arrays. If either left or right
value is null
@@ -856,15 +976,34 @@ where
/// Perform `left - right` operation on two arrays. If either left or right value is null
/// then the result is also null.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap around.
+/// For an overflow-checking variant, use `subtract_checked` instead.
pub fn subtract<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
- T::Native: Sub<Output = T::Native>,
+ T::Native: ArrowNativeTypeOp,
{
- math_op(left, right, |a, b| a - b)
+ math_op(left, right, |a, b| a.sub_wrapping(b))
+}
+
+/// Perform `left - right` operation on two arrays. If either left or right value is null
+/// then the result is also null.
+///
+/// This detects overflow and returns an `Err` for that. For a non-overflow-checking
+/// variant, use `subtract` instead.
+pub fn subtract_checked<T>(
+ left: &PrimitiveArray<T>,
+ right: &PrimitiveArray<T>,
+) -> Result<PrimitiveArray<T>>
+where
+ T: datatypes::ArrowNumericType,
+ T::Native: ArrowNativeTypeOp,
+{
+ math_checked_op(left, right, |a, b| a.sub_checked(b))
}
/// Perform `left - right` operation on two arrays. If either left or right
value is null
@@ -933,15 +1072,34 @@ where
/// Perform `left * right` operation on two arrays. If either left or right value is null
/// then the result is also null.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap around.
+/// For an overflow-checking variant, use `multiply_checked` instead.
pub fn multiply<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
- T::Native: Mul<Output = T::Native>,
+ T::Native: ArrowNativeTypeOp,
{
- math_op(left, right, |a, b| a * b)
+ math_op(left, right, |a, b| a.mul_wrapping(b))
+}
+
+/// Perform `left * right` operation on two arrays. If either left or right value is null
+/// then the result is also null.
+///
+/// This detects overflow and returns an `Err` for that. For a non-overflow-checking
+/// variant, use `multiply` instead.
+pub fn multiply_checked<T>(
+ left: &PrimitiveArray<T>,
+ right: &PrimitiveArray<T>,
+) -> Result<PrimitiveArray<T>>
+where
+ T: datatypes::ArrowNumericType,
+ T::Native: ArrowNativeTypeOp,
+{
+ math_checked_op(left, right, |a, b| a.mul_checked(b))
}
/// Perform `left * right` operation on two arrays. If either left or right
value is null
@@ -1013,18 +1171,21 @@ where
/// Perform `left / right` operation on two arrays. If either left or right value is null
/// then the result is also null. If any right hand value is zero then the result of this
/// operation will be `Err(ArrowError::DivideByZero)`.
-pub fn divide<T>(
+///
+/// When the `simd` feature is not enabled, this also detects overflow (e.g.
+/// `i32::MIN / -1`) and returns an `Err` for it. Overflow detection is not
+/// guaranteed on the `simd` code path. For a non-overflow-checking variant,
+/// use `divide` instead.
+pub fn divide_checked<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
- T::Native: Div<Output = T::Native> + Zero + One,
+ T::Native: ArrowNativeTypeOp + Zero + One,
{
#[cfg(feature = "simd")]
return simd_checked_divide_op(&left, &right, simd_checked_divide::<T>, |a, b| a / b);
#[cfg(not(feature = "simd"))]
- return math_checked_divide_op(left, right, |a, b| a / b);
+ return math_checked_divide(left, right, |a, b| a.div_checked(b));
}
/// Perform `left / right` operation on two arrays. If either left or right
value is null
@@ -1040,17 +1201,21 @@ pub fn divide_dyn(left: &dyn Array, right: &dyn Array)
-> Result<ArrayRef> {
}
/// Perform `left / right` operation on two arrays without checking for division by zero.
-/// The result of dividing by zero follows normal floating point rules.
+/// For floating point types, the result of dividing by zero follows normal floating point
+/// rules. For other numeric types, dividing by zero will panic.
-/// If either left or right value is null then the result is also null. If any right hand value is zero then the result of this
+/// If either left or right value is null then the result is also null.
-pub fn divide_unchecked<T>(
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap around.
+/// For an overflow-checking variant, use `divide_checked` instead.
+pub fn divide<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
- T: datatypes::ArrowFloatNumericType,
- T::Native: Div<Output = T::Native>,
+ T: datatypes::ArrowNumericType,
+ T::Native: ArrowNativeTypeOp,
{
- math_op(left, right, |a, b| a / b)
+ math_op(left, right, |a, b| a.div_wrapping(b))
}
/// Modulus every value in an array by a scalar. If any value in the array is
null then the
@@ -1769,7 +1934,7 @@ mod tests {
fn test_primitive_array_divide_with_nulls() {
let a = Int32Array::from(vec![Some(15), None, Some(8), Some(1),
Some(9), None]);
let b = Int32Array::from(vec![Some(5), Some(6), Some(8), Some(9),
None, None]);
- let c = divide(&a, &b).unwrap();
+ let c = divide_checked(&a, &b).unwrap();
assert_eq!(3, c.value(0));
assert!(c.is_null(1));
assert_eq!(1, c.value(2));
@@ -1854,7 +2019,7 @@ mod tests {
let b = b.slice(8, 6);
let b = b.as_any().downcast_ref::<Int32Array>().unwrap();
- let c = divide(a, b).unwrap();
+ let c = divide_checked(a, b).unwrap();
assert_eq!(6, c.len());
assert_eq!(3, c.value(0));
assert!(c.is_null(1));
@@ -1919,6 +2084,14 @@ mod tests {
#[test]
#[should_panic(expected = "DivideByZero")]
+ // `divide_checked` reports a zero divisor as an `Err` whose message contains
+ // "DivideByZero"; the `unwrap` turns it into the panic this test expects.
+ fn test_primitive_array_divide_by_zero_with_checked() {
+ let a = Int32Array::from(vec![15]);
+ let b = Int32Array::from(vec![0]);
+ divide_checked(&a, &b).unwrap();
+ }
+
+ #[test]
+ #[should_panic(expected = "attempt to divide by zero")]
fn test_primitive_array_divide_by_zero() {
let a = Int32Array::from(vec![15]);
let b = Int32Array::from(vec![0]);
@@ -2019,4 +2192,57 @@ mod tests {
let expected = Float64Array::from(vec![Some(1.0), None, Some(9.0)]);
assert_eq!(expected, actual);
}
+
+    #[test]
+    fn test_primitive_add_wrapping_overflow() {
+        let lhs = Int32Array::from(vec![i32::MAX, i32::MIN]);
+        let rhs = Int32Array::from(vec![1, 1]);
+
+        // `add` wraps: i32::MAX + 1 becomes i32::MIN; i32::MIN + 1 stays in range.
+        let expected = Int32Array::from(vec![i32::MIN, i32::MIN + 1]);
+        assert_eq!(expected, add(&lhs, &rhs).unwrap());
+
+        // `add_checked` must reject the same input.
+        add_checked(&lhs, &rhs).expect_err("overflow should be detected");
+    }
+
+    #[test]
+    fn test_primitive_subtract_wrapping_overflow() {
+        let lhs = Int32Array::from(vec![-2]);
+        let rhs = Int32Array::from(vec![i32::MAX]);
+
+        // -2 - i32::MAX == i32::MIN - 1, which wraps around to i32::MAX.
+        let expected = Int32Array::from(vec![i32::MAX]);
+        assert_eq!(expected, subtract(&lhs, &rhs).unwrap());
+
+        subtract_checked(&lhs, &rhs).expect_err("overflow should be detected");
+    }
+
+    #[test]
+    fn test_primitive_mul_wrapping_overflow() {
+        let lhs = Int32Array::from(vec![10]);
+        let rhs = Int32Array::from(vec![i32::MAX]);
+
+        // 10 * (2^31 - 1) == 5 * 2^32 - 10, which wraps to -10.
+        let expected = Int32Array::from(vec![-10]);
+        assert_eq!(expected, multiply(&lhs, &rhs).unwrap());
+
+        multiply_checked(&lhs, &rhs).expect_err("overflow should be detected");
+    }
+
+    #[test]
+    #[cfg(not(feature = "simd"))]
+    fn test_primitive_div_wrapping_overflow() {
+        let lhs = Int32Array::from(vec![i32::MIN]);
+        let rhs = Int32Array::from(vec![-1]);
+
+        // i32::MIN / -1 overflows; the wrapping result is i32::MIN itself.
+        let expected = Int32Array::from(vec![i32::MIN]);
+        assert_eq!(expected, divide(&lhs, &rhs).unwrap());
+
+        divide_checked(&lhs, &rhs).expect_err("overflow should be detected");
+    }
}
diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs
index 207e8cb40..444f2b27d 100644
--- a/arrow/src/datatypes/native.rs
+++ b/arrow/src/datatypes/native.rs
@@ -114,6 +114,112 @@ pub trait ArrowPrimitiveType: 'static {
}
}
+pub(crate) mod native_op {
+    use super::ArrowNativeType;
+    use std::ops::{Add, Div, Mul, Sub};
+
+    /// Arithmetic on [`ArrowNativeType`] values with explicit overflow semantics.
+    ///
+    /// Every operation comes in two flavours:
+    ///
+    /// * `*_wrapping`: never reports overflow. Implementations for integer types
+    ///   are expected to wrap around the boundary of the type.
+    /// * `*_checked`: returns `None` if the operation overflows.
+    ///
+    /// The provided default methods just apply the `+`, `-`, `*` and `/`
+    /// operators, so their `*_checked` forms always return `Some`. These defaults
+    /// are intended for floating point types; integer types that need overflow
+    /// handling override the methods.
+    pub trait ArrowNativeTypeOp:
+        ArrowNativeType
+        + Add<Output = Self>
+        + Sub<Output = Self>
+        + Mul<Output = Self>
+        + Div<Output = Self>
+    {
+        /// `self + rhs`, or `None` on overflow.
+        fn add_checked(self, rhs: Self) -> Option<Self> {
+            Some(self + rhs)
+        }
+
+        /// `self - rhs`, or `None` on overflow.
+        fn sub_checked(self, rhs: Self) -> Option<Self> {
+            Some(self - rhs)
+        }
+
+        /// `self * rhs`, or `None` on overflow.
+        fn mul_checked(self, rhs: Self) -> Option<Self> {
+            Some(self * rhs)
+        }
+
+        /// `self / rhs`, or `None` on overflow.
+        fn div_checked(self, rhs: Self) -> Option<Self> {
+            Some(self / rhs)
+        }
+
+        /// `self + rhs` without overflow detection.
+        fn add_wrapping(self, rhs: Self) -> Self {
+            self + rhs
+        }
+
+        /// `self - rhs` without overflow detection.
+        fn sub_wrapping(self, rhs: Self) -> Self {
+            self - rhs
+        }
+
+        /// `self * rhs` without overflow detection.
+        fn mul_wrapping(self, rhs: Self) -> Self {
+            self * rhs
+        }
+
+        /// `self / rhs` without overflow detection.
+        fn div_wrapping(self, rhs: Self) -> Self {
+            self / rhs
+        }
+    }
+}
+
+/// Implements `native_op::ArrowNativeTypeOp` for a primitive integer type by
+/// forwarding each method to the type's inherent `checked_*` / `wrapping_*` method.
+macro_rules! native_type_op {
+    ($t:tt) => {
+        impl native_op::ArrowNativeTypeOp for $t {
+            // Overflow-checking variants: `None` on overflow (for division,
+            // also on a zero divisor).
+            fn add_checked(self, rhs: Self) -> Option<Self> {
+                self.checked_add(rhs)
+            }
+
+            fn sub_checked(self, rhs: Self) -> Option<Self> {
+                self.checked_sub(rhs)
+            }
+
+            fn mul_checked(self, rhs: Self) -> Option<Self> {
+                self.checked_mul(rhs)
+            }
+
+            fn div_checked(self, rhs: Self) -> Option<Self> {
+                self.checked_div(rhs)
+            }
+
+            // Non-checking variants: wrap around the type's boundary.
+            // `wrapping_div` still panics on a zero divisor.
+            fn add_wrapping(self, rhs: Self) -> Self {
+                self.wrapping_add(rhs)
+            }
+
+            fn sub_wrapping(self, rhs: Self) -> Self {
+                self.wrapping_sub(rhs)
+            }
+
+            fn mul_wrapping(self, rhs: Self) -> Self {
+                self.wrapping_mul(rhs)
+            }
+
+            fn div_wrapping(self, rhs: Self) -> Self {
+                self.wrapping_div(rhs)
+            }
+        }
+    };
+}
+
+// Unsigned integers: overflow-aware implementations generated by the macro.
+native_type_op!(u8);
+native_type_op!(u16);
+native_type_op!(u32);
+native_type_op!(u64);
+
+// Signed integers: overflow-aware implementations generated by the macro.
+native_type_op!(i8);
+native_type_op!(i16);
+native_type_op!(i32);
+native_type_op!(i64);
+
+// Floating point types use the trait's default methods, which apply the plain
+// `+ - * /` operators; their `*_checked` methods always return `Some`.
+impl native_op::ArrowNativeTypeOp for f16 {}
+impl native_op::ArrowNativeTypeOp for f32 {}
+impl native_op::ArrowNativeTypeOp for f64 {}
+
impl private::Sealed for i8 {}
impl ArrowNativeType for i8 {
#[inline]