This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 7594db636 Add overflow-checking variants of arithmetic scalar dyn 
kernels (#2713)
7594db636 is described below

commit 7594db6367515473efdb130e7de91060079a4d88
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Sep 14 16:54:23 2022 -0700

    Add overflow-checking variants of arithmetic scalar dyn kernels (#2713)
    
    * Add overflow-checking variants of arithmetic scalar dyn kernels
    
    * Update doc
    
    * For review
---
 arrow/src/compute/kernels/arithmetic.rs | 199 ++++++++++++++++++++++++++++----
 arrow/src/compute/kernels/arity.rs      |  50 +++++++-
 2 files changed, 226 insertions(+), 23 deletions(-)

diff --git a/arrow/src/compute/kernels/arithmetic.rs 
b/arrow/src/compute/kernels/arithmetic.rs
index a344407e4..04fe2393e 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -22,7 +22,7 @@
 //! `RUSTFLAGS="-C target-feature=+avx2"` for example.  See the documentation
 //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
 
-use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
+use std::ops::{Div, Neg, Rem};
 
 use num::{One, Zero};
 
@@ -32,7 +32,9 @@ use crate::buffer::Buffer;
 use crate::buffer::MutableBuffer;
 use crate::compute::kernels::arity::unary;
 use crate::compute::util::combine_option_bitmap;
-use crate::compute::{binary, binary_opt, try_binary, try_unary, unary_dyn};
+use crate::compute::{
+    binary, binary_opt, try_binary, try_unary, try_unary_dyn, unary_dyn,
+};
 use crate::datatypes::{
     native_op::ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, 
Date64Type,
     IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, 
IntervalYearMonthType,
@@ -834,12 +836,39 @@ where
 /// Add every value in an array by a scalar. If any value in the array is null 
then the
 /// result is also null. The given array must be a `PrimitiveArray` of the 
type same as
 /// the scalar, or a `DictionaryArray` of the value type same as the scalar.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap 
around.
+/// For an overflow-checking variant, use `add_scalar_checked_dyn` instead.
+///
+/// This returns an `Err` when the input array is not supported for the add 
operation.
 pub fn add_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> 
Result<ArrayRef>
 where
     T: ArrowNumericType,
-    T::Native: Add<Output = T::Native>,
+    T::Native: ArrowNativeTypeOp,
+{
+    unary_dyn::<_, T>(array, |value| value.add_wrapping(scalar))
+}
+
+/// Add every value in an array by a scalar. If any value in the array is null 
then the
+/// result is also null. The given array must be a `PrimitiveArray` of the 
type same as
+/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
+///
+/// This detects overflow and returns an `Err` when it occurs. For a 
non-overflow-checking variant,
+/// use `add_scalar_dyn` instead.
+///
+/// Because this kernel incurs branching costs and also prevents LLVM from 
vectorising it correctly,
+/// it is usually much slower than the non-checking variant.
+pub fn add_scalar_checked_dyn<T>(array: &dyn Array, scalar: T::Native) -> 
Result<ArrayRef>
+where
+    T: ArrowNumericType,
+    T::Native: ArrowNativeTypeOp,
 {
-    unary_dyn::<_, T>(array, |value| value + scalar)
+    try_unary_dyn::<_, T>(array, |value| {
+        value.add_checked(scalar).ok_or_else(|| {
+            ArrowError::CastError(format!("Overflow: adding {:?} to {:?}", 
scalar, value))
+        })
+    })
+    .map(|a| Arc::new(a) as ArrayRef)
 }
 
 /// Perform `left - right` operation on two arrays. If either left or right 
value is null
@@ -937,16 +966,40 @@ where
 /// Subtract every value in an array by a scalar. If any value in the array is 
null then the
 /// result is also null. The given array must be a `PrimitiveArray` of the 
type same as
 /// the scalar, or a `DictionaryArray` of the value type same as the scalar.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap 
around.
+/// For an overflow-checking variant, use `subtract_scalar_checked_dyn` 
instead.
 pub fn subtract_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> 
Result<ArrayRef>
 where
-    T: datatypes::ArrowNumericType,
-    T::Native: Add<Output = T::Native>
-        + Sub<Output = T::Native>
-        + Mul<Output = T::Native>
-        + Div<Output = T::Native>
-        + Zero,
+    T: ArrowNumericType,
+    T::Native: ArrowNativeTypeOp,
+{
+    unary_dyn::<_, T>(array, |value| value.sub_wrapping(scalar))
+}
+
+/// Subtract every value in an array by a scalar. If any value in the array is 
null then the
+/// result is also null. The given array must be a `PrimitiveArray` of the 
type same as
+/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
+///
+/// This detects overflow and returns an `Err` when it occurs. For a 
non-overflow-checking variant,
+/// use `subtract_scalar_dyn` instead.
+pub fn subtract_scalar_checked_dyn<T>(
+    array: &dyn Array,
+    scalar: T::Native,
+) -> Result<ArrayRef>
+where
+    T: ArrowNumericType,
+    T::Native: ArrowNativeTypeOp,
 {
-    unary_dyn::<_, T>(array, |value| value - scalar)
+    try_unary_dyn::<_, T>(array, |value| {
+        value.sub_checked(scalar).ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Overflow: subtracting {:?} from {:?}",
+                scalar, value
+            ))
+        })
+    })
+    .map(|a| Arc::new(a) as ArrayRef)
 }
 
 /// Perform `-` operation on an array. If value is null then the result is 
also null.
@@ -1065,18 +1118,40 @@ where
 /// Multiply every value in an array by a scalar. If any value in the array is 
null then the
 /// result is also null. The given array must be a `PrimitiveArray` of the 
type same as
 /// the scalar, or a `DictionaryArray` of the value type same as the scalar.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap 
around.
+/// For an overflow-checking variant, use `multiply_scalar_checked_dyn` 
instead.
 pub fn multiply_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> 
Result<ArrayRef>
 where
     T: ArrowNumericType,
-    T::Native: Add<Output = T::Native>
-        + Sub<Output = T::Native>
-        + Mul<Output = T::Native>
-        + Div<Output = T::Native>
-        + Rem<Output = T::Native>
-        + Zero
-        + One,
+    T::Native: ArrowNativeTypeOp,
+{
+    unary_dyn::<_, T>(array, |value| value.mul_wrapping(scalar))
+}
+
+/// Multiply every value in an array by a scalar. If any value in the array is 
null then the
+/// result is also null. The given array must be a `PrimitiveArray` of the 
type same as
+/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
+///
+/// This detects overflow and returns an `Err` when it occurs. For a 
non-overflow-checking variant,
+/// use `multiply_scalar_dyn` instead.
+pub fn multiply_scalar_checked_dyn<T>(
+    array: &dyn Array,
+    scalar: T::Native,
+) -> Result<ArrayRef>
+where
+    T: ArrowNumericType,
+    T::Native: ArrowNativeTypeOp,
 {
-    unary_dyn::<_, T>(array, |value| value * scalar)
+    try_unary_dyn::<_, T>(array, |value| {
+        value.mul_checked(scalar).ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Overflow: multiplying {:?} by {:?}",
+                value, scalar
+            ))
+        })
+    })
+    .map(|a| Arc::new(a) as ArrayRef)
 }
 
 /// Perform `left % right` operation on two arrays. If either left or right 
value is null
@@ -1223,15 +1298,48 @@ where
 /// result is also null. If the scalar is zero then the result of this 
operation will be
 /// `Err(ArrowError::DivideByZero)`. The given array must be a 
`PrimitiveArray` of the type
 /// same as the scalar, or a `DictionaryArray` of the value type same as the 
scalar.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap 
around.
+/// For an overflow-checking variant, use `divide_scalar_checked_dyn` instead.
 pub fn divide_scalar_dyn<T>(array: &dyn Array, divisor: T::Native) -> 
Result<ArrayRef>
 where
     T: ArrowNumericType,
-    T::Native: Div<Output = T::Native> + Zero,
+    T::Native: ArrowNativeTypeOp + Zero,
+{
+    if divisor.is_zero() {
+        return Err(ArrowError::DivideByZero);
+    }
+    unary_dyn::<_, T>(array, |value| value.div_wrapping(divisor))
+}
+
+/// Divide every value in an array by a scalar. If any value in the array is 
null then the
+/// result is also null. If the scalar is zero then the result of this 
operation will be
+/// `Err(ArrowError::DivideByZero)`. The given array must be a 
`PrimitiveArray` of the type
+/// same as the scalar, or a `DictionaryArray` of the value type same as the 
scalar.
+///
+/// This detects overflow and returns an `Err` when it occurs. For a 
non-overflow-checking variant,
+/// use `divide_scalar_dyn` instead.
+pub fn divide_scalar_checked_dyn<T>(
+    array: &dyn Array,
+    divisor: T::Native,
+) -> Result<ArrayRef>
+where
+    T: ArrowNumericType,
+    T::Native: ArrowNativeTypeOp + Zero,
 {
     if divisor.is_zero() {
         return Err(ArrowError::DivideByZero);
     }
-    unary_dyn::<_, T>(array, |value| value / divisor)
+
+    try_unary_dyn::<_, T>(array, |value| {
+        value.div_checked(divisor).ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Overflow: dividing {:?} by {:?}",
+                value, divisor
+            ))
+        })
+    })
+    .map(|a| Arc::new(a) as ArrayRef)
 }
 
 #[cfg(test)]
@@ -2222,6 +2330,55 @@ mod tests {
         overflow.expect_err("overflow should be detected");
     }
 
+    #[test]
+    fn test_primitive_add_scalar_dyn_wrapping_overflow() {
+        let a = Int32Array::from(vec![i32::MAX, i32::MIN]);
+
+        let wrapped = add_scalar_dyn::<Int32Type>(&a, 1).unwrap();
+        let expected =
+            Arc::new(Int32Array::from(vec![-2147483648, -2147483647])) as 
ArrayRef;
+        assert_eq!(&expected, &wrapped);
+
+        let overflow = add_scalar_checked_dyn::<Int32Type>(&a, 1);
+        overflow.expect_err("overflow should be detected");
+    }
+
+    #[test]
+    fn test_primitive_subtract_scalar_dyn_wrapping_overflow() {
+        let a = Int32Array::from(vec![-2]);
+
+        let wrapped = subtract_scalar_dyn::<Int32Type>(&a, i32::MAX).unwrap();
+        let expected = Arc::new(Int32Array::from(vec![i32::MAX])) as ArrayRef;
+        assert_eq!(&expected, &wrapped);
+
+        let overflow = subtract_scalar_checked_dyn::<Int32Type>(&a, i32::MAX);
+        overflow.expect_err("overflow should be detected");
+    }
+
+    #[test]
+    fn test_primitive_mul_scalar_dyn_wrapping_overflow() {
+        let a = Int32Array::from(vec![10]);
+
+        let wrapped = multiply_scalar_dyn::<Int32Type>(&a, i32::MAX).unwrap();
+        let expected = Arc::new(Int32Array::from(vec![-10])) as ArrayRef;
+        assert_eq!(&expected, &wrapped);
+
+        let overflow = multiply_scalar_checked_dyn::<Int32Type>(&a, i32::MAX);
+        overflow.expect_err("overflow should be detected");
+    }
+
+    #[test]
+    fn test_primitive_div_scalar_dyn_wrapping_overflow() {
+        let a = Int32Array::from(vec![i32::MIN]);
+
+        let wrapped = divide_scalar_dyn::<Int32Type>(&a, -1).unwrap();
+        let expected = Arc::new(Int32Array::from(vec![-2147483648])) as 
ArrayRef;
+        assert_eq!(&expected, &wrapped);
+
+        let overflow = divide_scalar_checked_dyn::<Int32Type>(&a, -1);
+        overflow.expect_err("overflow should be detected");
+    }
+
     #[test]
     fn test_primitive_div_opt_overflow_division_by_zero() {
         let a = Int32Array::from(vec![i32::MIN]);
diff --git a/arrow/src/compute/kernels/arity.rs 
b/arrow/src/compute/kernels/arity.rs
index fffa81af8..21c633116 100644
--- a/arrow/src/compute/kernels/arity.rs
+++ b/arrow/src/compute/kernels/arity.rs
@@ -123,7 +123,7 @@ where
     Ok(unsafe { build_primitive_array(len, buffer.finish(), null_count, 
null_buffer) })
 }
 
-/// A helper function that applies an unary function to a dictionary array 
with primitive value type.
+/// A helper function that applies an infallible unary function to a 
dictionary array with primitive value type.
 fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
 where
     K: ArrowNumericType,
@@ -138,7 +138,22 @@ where
     Ok(Arc::new(new_dict))
 }
 
-/// Applies an unary function to an array with primitive values.
+/// A helper function that applies a fallible unary function to a dictionary 
array with primitive value type.
+fn try_unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> 
Result<ArrayRef>
+where
+    K: ArrowNumericType,
+    T: ArrowPrimitiveType,
+    F: Fn(T::Native) -> Result<T::Native>,
+{
+    let dict_values = array.values().as_any().downcast_ref().unwrap();
+    let values = try_unary::<T, F, T>(dict_values, op)?.into_data();
+    let data = array.data().clone().into_builder().child_data(vec![values]);
+
+    let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() 
}.into();
+    Ok(Arc::new(new_dict))
+}
+
+/// Applies an infallible unary function to an array with primitive values.
 pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
 where
     T: ArrowPrimitiveType,
@@ -162,6 +177,37 @@ where
     }
 }
 
+/// Applies a fallible unary function to an array with primitive values.
+pub fn try_unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
+where
+    T: ArrowPrimitiveType,
+    F: Fn(T::Native) -> Result<T::Native>,
+{
+    downcast_dictionary_array! {
+        array => if array.values().data_type() == &T::DATA_TYPE {
+            try_unary_dict::<_, F, T>(array, op)
+        } else {
+            Err(ArrowError::NotYetImplemented(format!(
+                "Cannot perform unary operation on dictionary array of type 
{}",
+                array.data_type()
+            )))
+        },
+        t => {
+            if t == &T::DATA_TYPE {
+                Ok(Arc::new(try_unary::<T, F, T>(
+                    
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
+                    op,
+                )?))
+            } else {
+                Err(ArrowError::NotYetImplemented(format!(
+                    "Cannot perform unary operation on array of type {}",
+                    t
+                )))
+            }
+        }
+    }
+}
+
 /// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in 
`0..len`, collecting
 /// the results in a [`PrimitiveArray`]. If any index is null in either `a` or 
`b`, the
 /// corresponding index in the result will also be null

Reply via email to