This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new e79ba40a6 Add overflow-checking variant of sum kernel (#2822)
e79ba40a6 is described below

commit e79ba40a6fa6f6e20211eee51c529332c6ed6f96
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Oct 5 10:51:58 2022 -0700

    Add overflow-checking variant of sum kernel (#2822)
    
    * Define overflow-checking behavior of sum kernels
    
    * Add sum_checked.
    
    * Add sum_array_checked.
---
 arrow/src/compute/kernels/aggregate.rs | 129 +++++++++++++++++++++++++++++++--
 1 file changed, 122 insertions(+), 7 deletions(-)

diff --git a/arrow/src/compute/kernels/aggregate.rs 
b/arrow/src/compute/kernels/aggregate.rs
index c215e2395..083defdde 100644
--- a/arrow/src/compute/kernels/aggregate.rs
+++ b/arrow/src/compute/kernels/aggregate.rs
@@ -17,14 +17,19 @@
 
 //! Defines aggregations over Arrow arrays.
 
+use arrow_data::bit_iterator::try_for_each_valid_idx;
+use arrow_schema::ArrowError;
 use multiversion::multiversion;
-use std::ops::Add;
+#[allow(unused_imports)]
+use std::ops::{Add, Deref};
 
 use crate::array::{
     as_primitive_array, Array, ArrayAccessor, ArrayIter, BooleanArray,
     GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
 };
+use crate::datatypes::native_op::ArrowNativeTypeOp;
 use crate::datatypes::{ArrowNativeType, ArrowNumericType, DataType};
+use crate::error::Result;
 use crate::util::bit_iterator::BitIndexIterator;
 
 /// Generic test for NaN, the optimizer should be able to remove this for 
integer types.
@@ -162,10 +167,13 @@ pub fn min_string<T: OffsetSizeTrait>(array: 
&GenericStringArray<T>) -> Option<&
 }
 
 /// Returns the sum of values in the array.
+///
+/// This doesn't detect overflow. If the sum overflows, the result wraps 
around.
+/// For an overflow-checking variant, use `sum_array_checked` instead.
 pub fn sum_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> 
Option<T::Native>
 where
     T: ArrowNumericType,
-    T::Native: Add<Output = T::Native>,
+    T::Native: ArrowNativeTypeOp,
 {
     match array.data_type() {
         DataType::Dictionary(_, _) => {
@@ -180,7 +188,7 @@ where
                 .into_iter()
                 .fold(T::default_value(), |accumulator, value| {
                     if let Some(value) = value {
-                        accumulator + value
+                        accumulator.add_wrapping(value)
                     } else {
                         accumulator
                     }
@@ -192,6 +200,42 @@ where
     }
 }
 
+/// Returns the sum of values in the array.
+///
+/// This detects overflow and returns an `Err` if it occurs. For a 
non-overflow-checking variant,
+/// use `sum_array` instead.
+pub fn sum_array_checked<T, A: ArrayAccessor<Item = T::Native>>(
+    array: A,
+) -> Result<Option<T::Native>>
+where
+    T: ArrowNumericType,
+    T::Native: ArrowNativeTypeOp,
+{
+    match array.data_type() {
+        DataType::Dictionary(_, _) => {
+            let null_count = array.null_count();
+
+            if null_count == array.len() {
+                return Ok(None);
+            }
+
+            let iter = ArrayIter::new(array);
+            let sum =
+                iter.into_iter()
+                    .try_fold(T::default_value(), |accumulator, value| {
+                        if let Some(value) = value {
+                            accumulator.add_checked(value)
+                        } else {
+                            Ok(accumulator)
+                        }
+                    })?;
+
+            Ok(Some(sum))
+        }
+        _ => sum_checked::<T>(as_primitive_array(&array)),
+    }
+}
+
 /// Returns the min of values in the array of `ArrowNumericType` type, or 
dictionary
 /// array with value of `ArrowNumericType` type.
 pub fn min_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> 
Option<T::Native>
@@ -239,11 +283,14 @@ where
 /// Returns the sum of values in the primitive array.
 ///
 /// Returns `None` if the array is empty or only contains null values.
+///
+/// This doesn't detect overflow. If the sum overflows, the result wraps 
around.
+/// For an overflow-checking variant, use `sum_checked` instead.
 #[cfg(not(feature = "simd"))]
 pub fn sum<T>(array: &PrimitiveArray<T>) -> Option<T::Native>
 where
     T: ArrowNumericType,
-    T::Native: Add<Output = T::Native>,
+    T::Native: ArrowNativeTypeOp,
 {
     let null_count = array.null_count();
 
@@ -256,7 +303,7 @@ where
     match array.data().null_buffer() {
         None => {
             let sum = data.iter().fold(T::default_value(), |accumulator, 
value| {
-                accumulator + *value
+                accumulator.add_wrapping(*value)
             });
 
             Some(sum)
@@ -274,7 +321,7 @@ where
                     let mut index_mask = 1;
                     chunk.iter().for_each(|value| {
                         if (mask & index_mask) != 0 {
-                            sum = sum + *value;
+                            sum = sum.add_wrapping(*value);
                         }
                         index_mask <<= 1;
                     });
@@ -284,7 +331,7 @@ where
 
             remainder.iter().enumerate().for_each(|(i, value)| {
                 if remainder_bits & (1 << i) != 0 {
-                    sum = sum + *value;
+                    sum = sum.add_wrapping(*value);
                 }
             });
 
@@ -293,6 +340,54 @@ where
     }
 }
 
+/// Returns the sum of values in the primitive array.
+///
+/// Returns `Ok(None)` if the array is empty or only contains null values.
+///
+/// This detects overflow and returns an `Err` if it occurs. For a 
non-overflow-checking variant,
+/// use `sum` instead.
+pub fn sum_checked<T>(array: &PrimitiveArray<T>) -> Result<Option<T::Native>>
+where
+    T: ArrowNumericType,
+    T::Native: ArrowNativeTypeOp,
+{
+    let null_count = array.null_count();
+
+    if null_count == array.len() {
+        return Ok(None);
+    }
+
+    let data: &[T::Native] = array.values();
+
+    match array.data().null_buffer() {
+        None => {
+            let sum = data
+                .iter()
+                .try_fold(T::default_value(), |accumulator, value| {
+                    accumulator.add_checked(*value)
+                })?;
+
+            Ok(Some(sum))
+        }
+        Some(buffer) => {
+            let mut sum = T::default_value();
+
+            try_for_each_valid_idx(
+                array.len(),
+                array.offset(),
+                null_count,
+                Some(buffer.deref()),
+                |idx| {
+                    unsafe { sum = 
sum.add_checked(array.value_unchecked(idx))? };
+                    Ok::<_, ArrowError>(())
+                },
+            )?;
+
+            Ok(Some(sum))
+        }
+    }
+}
+
 #[cfg(feature = "simd")]
 mod simd {
     use super::is_nan;
@@ -638,6 +733,9 @@ mod simd {
 /// Returns the sum of values in the primitive array.
 ///
 /// Returns `None` if the array is empty or only contains null values.
+///
+/// This doesn't detect overflow in release mode by default. If the sum overflows, 
the result will
+/// wrap around. For an overflow-checking variant, use `sum_checked` instead.
 #[cfg(feature = "simd")]
 pub fn sum<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Option<T::Native>
 where
@@ -1216,4 +1314,21 @@ mod tests {
         let actual = max_binary(sliced_input);
         assert_eq!(actual, expected);
     }
+
+    #[test]
+    #[cfg(not(feature = "simd"))]
+    fn test_sum_overflow() {
+        let a = Int32Array::from(vec![i32::MAX, 1]);
+
+        assert_eq!(sum(&a).unwrap(), -2147483648);
+        assert_eq!(sum_array::<Int32Type, _>(&a).unwrap(), -2147483648);
+    }
+
+    #[test]
+    fn test_sum_checked_overflow() {
+        let a = Int32Array::from(vec![i32::MAX, 1]);
+
+        sum_checked(&a).expect_err("overflow should be detected");
+        sum_array_checked::<Int32Type, _>(&a).expect_err("overflow should be 
detected");
+    }
 }

Reply via email to