This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 6d86472fa Add overflow-checking variant for primitive arithmetic kernels and explicitly define overflow behavior (#2643)
6d86472fa is described below

commit 6d86472fa3c68986dc1813d3cb027748472ec22f
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sun Sep 4 02:44:36 2022 -0700

    Add overflow-checking variant for primitive arithmetic kernels and explicitly define overflow behavior (#2643)
    
    * Add overflow-checking variant for add kernel and explicitly define overflow behavior for add
    
    * For subtract, multiply, divide
    
    * Fix tests
    
    * Fix different error message
    
    * Fix typo
    
    * Rename APIs and add more comments. Print values in error message.
    
    * Add one more test to distinguish the divide-by-zero behavior of divide and divide_checked.
    
    * Fix clippy
    
    * Update divide doc with dividing by zero behavior for other numeric types.
    
    * Hide ArrowNativeTypeOp
    
    * Fix a typo
---
 arrow/benches/arithmetic_kernels.rs     |   4 +-
 arrow/src/compute/kernels/arithmetic.rs | 262 +++++++++++++++++++++++++++++---
 arrow/src/datatypes/native.rs           | 106 +++++++++++++
 3 files changed, 352 insertions(+), 20 deletions(-)

diff --git a/arrow/benches/arithmetic_kernels.rs 
b/arrow/benches/arithmetic_kernels.rs
index 4be4a2693..10af0b543 100644
--- a/arrow/benches/arithmetic_kernels.rs
+++ b/arrow/benches/arithmetic_kernels.rs
@@ -55,13 +55,13 @@ fn bench_multiply(arr_a: &ArrayRef, arr_b: &ArrayRef) {
 fn bench_divide(arr_a: &ArrayRef, arr_b: &ArrayRef) {
     let arr_a = arr_a.as_any().downcast_ref::<Float32Array>().unwrap();
     let arr_b = arr_b.as_any().downcast_ref::<Float32Array>().unwrap();
-    criterion::black_box(divide(arr_a, arr_b).unwrap());
+    criterion::black_box(divide_checked(arr_a, arr_b).unwrap());
 }
 
 fn bench_divide_unchecked(arr_a: &ArrayRef, arr_b: &ArrayRef) {
     let arr_a = arr_a.as_any().downcast_ref::<Float32Array>().unwrap();
     let arr_b = arr_b.as_any().downcast_ref::<Float32Array>().unwrap();
-    criterion::black_box(divide_unchecked(arr_a, arr_b).unwrap());
+    criterion::black_box(divide(arr_a, arr_b).unwrap());
 }
 
 fn bench_divide_scalar(array: &ArrayRef, divisor: f32) {
diff --git a/arrow/src/compute/kernels/arithmetic.rs 
b/arrow/src/compute/kernels/arithmetic.rs
index fff687e18..53f48570d 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -35,8 +35,9 @@ use crate::compute::unary_dyn;
 use crate::compute::util::combine_option_bitmap;
 use crate::datatypes;
 use crate::datatypes::{
-    ArrowNumericType, DataType, Date32Type, Date64Type, IntervalDayTimeType,
-    IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
+    native_op::ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, 
DataType,
+    Date32Type, Date64Type, IntervalDayTimeType, IntervalMonthDayNanoType, 
IntervalUnit,
+    IntervalYearMonthType,
 };
 use crate::datatypes::{
     Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, 
UInt16Type,
@@ -103,6 +104,106 @@ where
     Ok(PrimitiveArray::<LT>::from(data))
 }
 
+/// This is similar to `math_op` as it performs given operation between two 
input primitive arrays.
+/// But the given operation can return `None` if overflow is detected. For the 
case, this function
+/// returns an `Err`.
+fn math_checked_op<LT, RT, F>(
+    left: &PrimitiveArray<LT>,
+    right: &PrimitiveArray<RT>,
+    op: F,
+) -> Result<PrimitiveArray<LT>>
+where
+    LT: ArrowNumericType,
+    RT: ArrowNumericType,
+    F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
+{
+    if left.len() != right.len() {
+        return Err(ArrowError::ComputeError(
+            "Cannot perform math operation on arrays of different 
length".to_string(),
+        ));
+    }
+
+    let left_iter = ArrayIter::new(left);
+    let right_iter = ArrayIter::new(right);
+
+    let values: Result<Vec<Option<<LT as ArrowPrimitiveType>::Native>>> = 
left_iter
+        .into_iter()
+        .zip(right_iter.into_iter())
+        .map(|(l, r)| {
+            if let (Some(l), Some(r)) = (l, r) {
+                let result = op(l, r);
+                if let Some(r) = result {
+                    Ok(Some(r))
+                } else {
+                    // Overflow
+                    Err(ArrowError::ComputeError(format!(
+                        "Overflow happened on: {:?}, {:?}",
+                        l, r
+                    )))
+                }
+            } else {
+                Ok(None)
+            }
+        })
+        .collect();
+
+    let values = values?;
+
+    Ok(PrimitiveArray::<LT>::from_iter(values))
+}
+
+/// This is similar to `math_checked_op` but just for divide op.
+fn math_checked_divide<LT, RT, F>(
+    left: &PrimitiveArray<LT>,
+    right: &PrimitiveArray<RT>,
+    op: F,
+) -> Result<PrimitiveArray<LT>>
+where
+    LT: ArrowNumericType,
+    RT: ArrowNumericType,
+    RT::Native: One + Zero,
+    F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
+{
+    if left.len() != right.len() {
+        return Err(ArrowError::ComputeError(
+            "Cannot perform math operation on arrays of different 
length".to_string(),
+        ));
+    }
+
+    let left_iter = ArrayIter::new(left);
+    let right_iter = ArrayIter::new(right);
+
+    let values: Result<Vec<Option<<LT as ArrowPrimitiveType>::Native>>> = 
left_iter
+        .into_iter()
+        .zip(right_iter.into_iter())
+        .map(|(l, r)| {
+            if let (Some(l), Some(r)) = (l, r) {
+                let result = op(l, r);
+                if let Some(r) = result {
+                    Ok(Some(r))
+                } else if r.is_zero() {
+                    Err(ArrowError::ComputeError(format!(
+                        "DivideByZero on: {:?}, {:?}",
+                        l, r
+                    )))
+                } else {
+                    // Overflow
+                    Err(ArrowError::ComputeError(format!(
+                        "Overflow happened on: {:?}, {:?}",
+                        l, r
+                    )))
+                }
+            } else {
+                Ok(None)
+            }
+        })
+        .collect();
+
+    let values = values?;
+
+    Ok(PrimitiveArray::<LT>::from_iter(values))
+}
+
 /// Helper function for operations where a valid `0` on the right array should
 /// result in an [ArrowError::DivideByZero], namely the division and modulo 
operations
 ///
@@ -760,15 +861,34 @@ where
 
 /// Perform `left + right` operation on two arrays. If either left or right 
value is null
 /// then the result is also null.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap 
around.
+/// For an overflow-checking variant, use `add_checked` instead.
 pub fn add<T>(
     left: &PrimitiveArray<T>,
     right: &PrimitiveArray<T>,
 ) -> Result<PrimitiveArray<T>>
 where
     T: ArrowNumericType,
-    T::Native: Add<Output = T::Native>,
+    T::Native: ArrowNativeTypeOp,
+{
+    math_op(left, right, |a, b| a.add_wrapping(b))
+}
+
+/// Perform `left + right` operation on two arrays. If either left or right 
value is null
+/// then the result is also null. Once
+///
+/// This detects overflow and returns an `Err` for that. For an 
non-overflow-checking variant,
+/// use `add` instead.
+pub fn add_checked<T>(
+    left: &PrimitiveArray<T>,
+    right: &PrimitiveArray<T>,
+) -> Result<PrimitiveArray<T>>
+where
+    T: ArrowNumericType,
+    T::Native: ArrowNativeTypeOp,
 {
-    math_op(left, right, |a, b| a + b)
+    math_checked_op(left, right, |a, b| a.add_checked(b))
 }
 
 /// Perform `left + right` operation on two arrays. If either left or right 
value is null
@@ -856,15 +976,34 @@ where
 
 /// Perform `left - right` operation on two arrays. If either left or right 
value is null
 /// then the result is also null.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap 
around.
+/// For an overflow-checking variant, use `subtract_checked` instead.
 pub fn subtract<T>(
     left: &PrimitiveArray<T>,
     right: &PrimitiveArray<T>,
 ) -> Result<PrimitiveArray<T>>
 where
     T: datatypes::ArrowNumericType,
-    T::Native: Sub<Output = T::Native>,
+    T::Native: ArrowNativeTypeOp,
 {
-    math_op(left, right, |a, b| a - b)
+    math_op(left, right, |a, b| a.sub_wrapping(b))
+}
+
+/// Perform `left - right` operation on two arrays. If either left or right 
value is null
+/// then the result is also null.
+///
+/// This detects overflow and returns an `Err` for that. For an 
non-overflow-checking variant,
+/// use `subtract` instead.
+pub fn subtract_checked<T>(
+    left: &PrimitiveArray<T>,
+    right: &PrimitiveArray<T>,
+) -> Result<PrimitiveArray<T>>
+where
+    T: datatypes::ArrowNumericType,
+    T::Native: ArrowNativeTypeOp,
+{
+    math_checked_op(left, right, |a, b| a.sub_checked(b))
 }
 
 /// Perform `left - right` operation on two arrays. If either left or right 
value is null
@@ -933,15 +1072,34 @@ where
 
 /// Perform `left * right` operation on two arrays. If either left or right 
value is null
 /// then the result is also null.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap 
around.
+/// For an overflow-checking variant, use `multiply_check` instead.
 pub fn multiply<T>(
     left: &PrimitiveArray<T>,
     right: &PrimitiveArray<T>,
 ) -> Result<PrimitiveArray<T>>
 where
     T: datatypes::ArrowNumericType,
-    T::Native: Mul<Output = T::Native>,
+    T::Native: ArrowNativeTypeOp,
 {
-    math_op(left, right, |a, b| a * b)
+    math_op(left, right, |a, b| a.mul_wrapping(b))
+}
+
+/// Perform `left * right` operation on two arrays. If either left or right 
value is null
+/// then the result is also null.
+///
+/// This detects overflow and returns an `Err` for that. For an 
non-overflow-checking variant,
+/// use `multiply` instead.
+pub fn multiply_checked<T>(
+    left: &PrimitiveArray<T>,
+    right: &PrimitiveArray<T>,
+) -> Result<PrimitiveArray<T>>
+where
+    T: datatypes::ArrowNumericType,
+    T::Native: ArrowNativeTypeOp,
+{
+    math_checked_op(left, right, |a, b| a.mul_checked(b))
 }
 
 /// Perform `left * right` operation on two arrays. If either left or right 
value is null
@@ -1013,18 +1171,21 @@ where
 /// Perform `left / right` operation on two arrays. If either left or right 
value is null
 /// then the result is also null. If any right hand value is zero then the 
result of this
 /// operation will be `Err(ArrowError::DivideByZero)`.
-pub fn divide<T>(
+///
+/// When `simd` feature is not enabled. This detects overflow and returns an 
`Err` for that.
+/// For an non-overflow-checking variant, use `divide` instead.
+pub fn divide_checked<T>(
     left: &PrimitiveArray<T>,
     right: &PrimitiveArray<T>,
 ) -> Result<PrimitiveArray<T>>
 where
     T: datatypes::ArrowNumericType,
-    T::Native: Div<Output = T::Native> + Zero + One,
+    T::Native: ArrowNativeTypeOp + Zero + One,
 {
     #[cfg(feature = "simd")]
     return simd_checked_divide_op(&left, &right, simd_checked_divide::<T>, |a, 
b| a / b);
     #[cfg(not(feature = "simd"))]
-    return math_checked_divide_op(left, right, |a, b| a / b);
+    return math_checked_divide(left, right, |a, b| a.div_checked(b));
 }
 
 /// Perform `left / right` operation on two arrays. If either left or right 
value is null
@@ -1040,17 +1201,21 @@ pub fn divide_dyn(left: &dyn Array, right: &dyn Array) 
-> Result<ArrayRef> {
 }
 
 /// Perform `left / right` operation on two arrays without checking for 
division by zero.
-/// The result of dividing by zero follows normal floating point rules.
+/// For floating point types, the result of dividing by zero follows normal 
floating point
+/// rules. For other numeric types, dividing by zero will panic,
 /// If either left or right value is null then the result is also null. If any 
right hand value is zero then the result of this
-pub fn divide_unchecked<T>(
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap 
around.
+/// For an overflow-checking variant, use `divide_checked` instead.
+pub fn divide<T>(
     left: &PrimitiveArray<T>,
     right: &PrimitiveArray<T>,
 ) -> Result<PrimitiveArray<T>>
 where
-    T: datatypes::ArrowFloatNumericType,
-    T::Native: Div<Output = T::Native>,
+    T: datatypes::ArrowNumericType,
+    T::Native: ArrowNativeTypeOp,
 {
-    math_op(left, right, |a, b| a / b)
+    math_op(left, right, |a, b| a.div_wrapping(b))
 }
 
 /// Modulus every value in an array by a scalar. If any value in the array is 
null then the
@@ -1769,7 +1934,7 @@ mod tests {
     fn test_primitive_array_divide_with_nulls() {
         let a = Int32Array::from(vec![Some(15), None, Some(8), Some(1), 
Some(9), None]);
         let b = Int32Array::from(vec![Some(5), Some(6), Some(8), Some(9), 
None, None]);
-        let c = divide(&a, &b).unwrap();
+        let c = divide_checked(&a, &b).unwrap();
         assert_eq!(3, c.value(0));
         assert!(c.is_null(1));
         assert_eq!(1, c.value(2));
@@ -1854,7 +2019,7 @@ mod tests {
         let b = b.slice(8, 6);
         let b = b.as_any().downcast_ref::<Int32Array>().unwrap();
 
-        let c = divide(a, b).unwrap();
+        let c = divide_checked(a, b).unwrap();
         assert_eq!(6, c.len());
         assert_eq!(3, c.value(0));
         assert!(c.is_null(1));
@@ -1919,6 +2084,14 @@ mod tests {
 
     #[test]
     #[should_panic(expected = "DivideByZero")]
+    fn test_primitive_array_divide_by_zero_with_checked() {
+        let a = Int32Array::from(vec![15]);
+        let b = Int32Array::from(vec![0]);
+        divide_checked(&a, &b).unwrap();
+    }
+
+    #[test]
+    #[should_panic(expected = "attempt to divide by zero")]
     fn test_primitive_array_divide_by_zero() {
         let a = Int32Array::from(vec![15]);
         let b = Int32Array::from(vec![0]);
@@ -2019,4 +2192,57 @@ mod tests {
         let expected = Float64Array::from(vec![Some(1.0), None, Some(9.0)]);
         assert_eq!(expected, actual);
     }
+
+    #[test]
+    fn test_primitive_add_wrapping_overflow() {
+        let a = Int32Array::from(vec![i32::MAX, i32::MIN]);
+        let b = Int32Array::from(vec![1, 1]);
+
+        let wrapped = add(&a, &b);
+        let expected = Int32Array::from(vec![-2147483648, -2147483647]);
+        assert_eq!(expected, wrapped.unwrap());
+
+        let overflow = add_checked(&a, &b);
+        overflow.expect_err("overflow should be detected");
+    }
+
+    #[test]
+    fn test_primitive_subtract_wrapping_overflow() {
+        let a = Int32Array::from(vec![-2]);
+        let b = Int32Array::from(vec![i32::MAX]);
+
+        let wrapped = subtract(&a, &b);
+        let expected = Int32Array::from(vec![i32::MAX]);
+        assert_eq!(expected, wrapped.unwrap());
+
+        let overflow = subtract_checked(&a, &b);
+        overflow.expect_err("overflow should be detected");
+    }
+
+    #[test]
+    fn test_primitive_mul_wrapping_overflow() {
+        let a = Int32Array::from(vec![10]);
+        let b = Int32Array::from(vec![i32::MAX]);
+
+        let wrapped = multiply(&a, &b);
+        let expected = Int32Array::from(vec![-10]);
+        assert_eq!(expected, wrapped.unwrap());
+
+        let overflow = multiply_checked(&a, &b);
+        overflow.expect_err("overflow should be detected");
+    }
+
+    #[test]
+    #[cfg(not(feature = "simd"))]
+    fn test_primitive_div_wrapping_overflow() {
+        let a = Int32Array::from(vec![i32::MIN]);
+        let b = Int32Array::from(vec![-1]);
+
+        let wrapped = divide(&a, &b);
+        let expected = Int32Array::from(vec![-2147483648]);
+        assert_eq!(expected, wrapped.unwrap());
+
+        let overflow = divide_checked(&a, &b);
+        overflow.expect_err("overflow should be detected");
+    }
 }
diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs
index 207e8cb40..444f2b27d 100644
--- a/arrow/src/datatypes/native.rs
+++ b/arrow/src/datatypes/native.rs
@@ -114,6 +114,112 @@ pub trait ArrowPrimitiveType: 'static {
     }
 }
 
+pub(crate) mod native_op {
+    use super::ArrowNativeType;
+    use std::ops::{Add, Div, Mul, Sub};
+
+    /// Trait for ArrowNativeType to provide overflow-checking and 
non-overflow-checking
+    /// variants for arithmetic operations. For floating point types, this 
provides some
+    /// default implementations. Integer types that need to deal with overflow 
can implement
+    /// this trait.
+    ///
+    /// The APIs with `_wrapping` suffix are the variant of 
non-overflow-checking. If overflow
+    /// occurred, they will supposedly wrap around the boundary of the type.
+    ///
+    /// The APIs with `_checked` suffix are the variant of overflow-checking 
which return `None`
+    /// if overflow occurred.
+    pub trait ArrowNativeTypeOp:
+        ArrowNativeType
+        + Add<Output = Self>
+        + Sub<Output = Self>
+        + Mul<Output = Self>
+        + Div<Output = Self>
+    {
+        fn add_checked(self, rhs: Self) -> Option<Self> {
+            Some(self + rhs)
+        }
+
+        fn add_wrapping(self, rhs: Self) -> Self {
+            self + rhs
+        }
+
+        fn sub_checked(self, rhs: Self) -> Option<Self> {
+            Some(self - rhs)
+        }
+
+        fn sub_wrapping(self, rhs: Self) -> Self {
+            self - rhs
+        }
+
+        fn mul_checked(self, rhs: Self) -> Option<Self> {
+            Some(self * rhs)
+        }
+
+        fn mul_wrapping(self, rhs: Self) -> Self {
+            self * rhs
+        }
+
+        fn div_checked(self, rhs: Self) -> Option<Self> {
+            Some(self / rhs)
+        }
+
+        fn div_wrapping(self, rhs: Self) -> Self {
+            self / rhs
+        }
+    }
+}
+
+macro_rules! native_type_op {
+    ($t:tt) => {
+        impl native_op::ArrowNativeTypeOp for $t {
+            fn add_checked(self, rhs: Self) -> Option<Self> {
+                self.checked_add(rhs)
+            }
+
+            fn add_wrapping(self, rhs: Self) -> Self {
+                self.wrapping_add(rhs)
+            }
+
+            fn sub_checked(self, rhs: Self) -> Option<Self> {
+                self.checked_sub(rhs)
+            }
+
+            fn sub_wrapping(self, rhs: Self) -> Self {
+                self.wrapping_sub(rhs)
+            }
+
+            fn mul_checked(self, rhs: Self) -> Option<Self> {
+                self.checked_mul(rhs)
+            }
+
+            fn mul_wrapping(self, rhs: Self) -> Self {
+                self.wrapping_mul(rhs)
+            }
+
+            fn div_checked(self, rhs: Self) -> Option<Self> {
+                self.checked_div(rhs)
+            }
+
+            fn div_wrapping(self, rhs: Self) -> Self {
+                self.wrapping_div(rhs)
+            }
+        }
+    };
+}
+
+native_type_op!(i8);
+native_type_op!(i16);
+native_type_op!(i32);
+native_type_op!(i64);
+native_type_op!(u8);
+native_type_op!(u16);
+native_type_op!(u32);
+native_type_op!(u64);
+
+impl native_op::ArrowNativeTypeOp for f16 {}
+impl native_op::ArrowNativeTypeOp for f32 {}
+impl native_op::ArrowNativeTypeOp for f64 {}
+
 impl private::Sealed for i8 {}
 impl ArrowNativeType for i8 {
     #[inline]

Reply via email to