This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new e79ba40a6 Add overflow-checking variant of sum kernel (#2822)
e79ba40a6 is described below
commit e79ba40a6fa6f6e20211eee51c529332c6ed6f96
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Oct 5 10:51:58 2022 -0700
Add overflow-checking variant of sum kernel (#2822)
* Define overflow-checking behavior of sum kernels
* Add sum_checked.
* Add sum_array_checked.
---
arrow/src/compute/kernels/aggregate.rs | 129 +++++++++++++++++++++++++++++++--
1 file changed, 122 insertions(+), 7 deletions(-)
diff --git a/arrow/src/compute/kernels/aggregate.rs b/arrow/src/compute/kernels/aggregate.rs
index c215e2395..083defdde 100644
--- a/arrow/src/compute/kernels/aggregate.rs
+++ b/arrow/src/compute/kernels/aggregate.rs
@@ -17,14 +17,19 @@
//! Defines aggregations over Arrow arrays.
+use arrow_data::bit_iterator::try_for_each_valid_idx;
+use arrow_schema::ArrowError;
use multiversion::multiversion;
-use std::ops::Add;
+#[allow(unused_imports)]
+use std::ops::{Add, Deref};
use crate::array::{
as_primitive_array, Array, ArrayAccessor, ArrayIter, BooleanArray,
GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
};
+use crate::datatypes::native_op::ArrowNativeTypeOp;
use crate::datatypes::{ArrowNativeType, ArrowNumericType, DataType};
+use crate::error::Result;
use crate::util::bit_iterator::BitIndexIterator;
/// Generic test for NaN, the optimizer should be able to remove this for integer types.
@@ -162,10 +167,13 @@ pub fn min_string<T: OffsetSizeTrait>(array: &GenericStringArray<T>) -> Option<&
}
/// Returns the sum of values in the array.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap around.
+/// For an overflow-checking variant, use `sum_array_checked` instead.
pub fn sum_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
where
T: ArrowNumericType,
- T::Native: Add<Output = T::Native>,
+ T::Native: ArrowNativeTypeOp,
{
match array.data_type() {
DataType::Dictionary(_, _) => {
@@ -180,7 +188,7 @@ where
.into_iter()
.fold(T::default_value(), |accumulator, value| {
if let Some(value) = value {
- accumulator + value
+ accumulator.add_wrapping(value)
} else {
accumulator
}
@@ -192,6 +200,42 @@ where
}
}
+/// Returns the sum of values in the array.
+///
+/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
+/// use `sum_array` instead.
+pub fn sum_array_checked<T, A: ArrayAccessor<Item = T::Native>>(
+ array: A,
+) -> Result<Option<T::Native>>
+where
+ T: ArrowNumericType,
+ T::Native: ArrowNativeTypeOp,
+{
+ match array.data_type() {
+ DataType::Dictionary(_, _) => {
+ let null_count = array.null_count();
+
+ if null_count == array.len() {
+ return Ok(None);
+ }
+
+ let iter = ArrayIter::new(array);
+ let sum =
+ iter.into_iter()
+ .try_fold(T::default_value(), |accumulator, value| {
+ if let Some(value) = value {
+ accumulator.add_checked(value)
+ } else {
+ Ok(accumulator)
+ }
+ })?;
+
+ Ok(Some(sum))
+ }
+ _ => sum_checked::<T>(as_primitive_array(&array)),
+ }
+}
+
/// Returns the min of values in the array of `ArrowNumericType` type, or dictionary
/// array with value of `ArrowNumericType` type.
pub fn min_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
@@ -239,11 +283,14 @@ where
/// Returns the sum of values in the primitive array.
///
/// Returns `None` if the array is empty or only contains null values.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap around.
+/// For an overflow-checking variant, use `sum_checked` instead.
#[cfg(not(feature = "simd"))]
pub fn sum<T>(array: &PrimitiveArray<T>) -> Option<T::Native>
where
T: ArrowNumericType,
- T::Native: Add<Output = T::Native>,
+ T::Native: ArrowNativeTypeOp,
{
let null_count = array.null_count();
@@ -256,7 +303,7 @@ where
match array.data().null_buffer() {
None => {
let sum = data.iter().fold(T::default_value(), |accumulator, value| {
- accumulator + *value
+ accumulator.add_wrapping(*value)
});
Some(sum)
@@ -274,7 +321,7 @@ where
let mut index_mask = 1;
chunk.iter().for_each(|value| {
if (mask & index_mask) != 0 {
- sum = sum + *value;
+ sum = sum.add_wrapping(*value);
}
index_mask <<= 1;
});
@@ -284,7 +331,7 @@ where
remainder.iter().enumerate().for_each(|(i, value)| {
if remainder_bits & (1 << i) != 0 {
- sum = sum + *value;
+ sum = sum.add_wrapping(*value);
}
});
@@ -293,6 +340,54 @@ where
}
}
+/// Returns the sum of values in the primitive array.
+///
+/// Returns `Ok(None)` if the array is empty or only contains null values.
+///
+/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
+/// use `sum` instead.
+pub fn sum_checked<T>(array: &PrimitiveArray<T>) -> Result<Option<T::Native>>
+where
+ T: ArrowNumericType,
+ T::Native: ArrowNativeTypeOp,
+{
+ let null_count = array.null_count();
+
+ if null_count == array.len() {
+ return Ok(None);
+ }
+
+ let data: &[T::Native] = array.values();
+
+ match array.data().null_buffer() {
+ None => {
+ let sum = data
+ .iter()
+ .try_fold(T::default_value(), |accumulator, value| {
+ accumulator.add_checked(*value)
+ })?;
+
+ Ok(Some(sum))
+ }
+ Some(buffer) => {
+ let mut sum = T::default_value();
+
+ try_for_each_valid_idx(
+ array.len(),
+ array.offset(),
+ null_count,
+ Some(buffer.deref()),
+ |idx| {
+ unsafe { sum = sum.add_checked(array.value_unchecked(idx))? };
+ Ok::<_, ArrowError>(())
+ },
+ )?;
+
+ Ok(Some(sum))
+ }
+ }
+}
+
#[cfg(feature = "simd")]
mod simd {
use super::is_nan;
@@ -638,6 +733,9 @@ mod simd {
/// Returns the sum of values in the primitive array.
///
/// Returns `None` if the array is empty or only contains null values.
+///
+/// This doesn't detect overflow in release mode by default. Once overflowing, the result will
+/// wrap around. For an overflow-checking variant, use `sum_checked` instead.
#[cfg(feature = "simd")]
pub fn sum<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Option<T::Native>
where
@@ -1216,4 +1314,21 @@ mod tests {
let actual = max_binary(sliced_input);
assert_eq!(actual, expected);
}
+
+ #[test]
+ #[cfg(not(feature = "simd"))]
+ fn test_sum_overflow() {
+ let a = Int32Array::from(vec![i32::MAX, 1]);
+
+ assert_eq!(sum(&a).unwrap(), -2147483648);
+ assert_eq!(sum_array::<Int32Type, _>(&a).unwrap(), -2147483648);
+ }
+
+ #[test]
+ fn test_sum_checked_overflow() {
+ let a = Int32Array::from(vec![i32::MAX, 1]);
+
+ sum_checked(&a).expect_err("overflow should be detected");
+ sum_array_checked::<Int32Type, _>(&a).expect_err("overflow should be detected");
+ }
}