This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 7594db636 Add overflow-checking variants of arithmetic scalar dyn
kernels (#2713)
7594db636 is described below
commit 7594db6367515473efdb130e7de91060079a4d88
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Sep 14 16:54:23 2022 -0700
Add overflow-checking variants of arithmetic scalar dyn kernels (#2713)
* Add overflow-checking variants of arithmetic scalar dyn kernels
* Update doc
* For review
---
arrow/src/compute/kernels/arithmetic.rs | 199 ++++++++++++++++++++++++++++----
arrow/src/compute/kernels/arity.rs | 50 +++++++-
2 files changed, 226 insertions(+), 23 deletions(-)
diff --git a/arrow/src/compute/kernels/arithmetic.rs
b/arrow/src/compute/kernels/arithmetic.rs
index a344407e4..04fe2393e 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -22,7 +22,7 @@
//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
-use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
+use std::ops::{Div, Neg, Rem};
use num::{One, Zero};
@@ -32,7 +32,9 @@ use crate::buffer::Buffer;
use crate::buffer::MutableBuffer;
use crate::compute::kernels::arity::unary;
use crate::compute::util::combine_option_bitmap;
-use crate::compute::{binary, binary_opt, try_binary, try_unary, unary_dyn};
+use crate::compute::{
+ binary, binary_opt, try_binary, try_unary, try_unary_dyn, unary_dyn,
+};
use crate::datatypes::{
native_op::ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type,
Date64Type,
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
IntervalYearMonthType,
@@ -834,12 +836,39 @@ where
/// Add every value in an array by a scalar. If any value in the array is null
then the
/// result is also null. The given array must be a `PrimitiveArray` of the
type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap
around.
+/// For an overflow-checking variant, use `add_scalar_checked_dyn` instead.
+///
+/// This returns an `Err` when the input array type is not supported for the add
operation.
pub fn add_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) ->
Result<ArrayRef>
where
T: ArrowNumericType,
- T::Native: Add<Output = T::Native>,
+ T::Native: ArrowNativeTypeOp,
+{
+ unary_dyn::<_, T>(array, |value| value.add_wrapping(scalar))
+}
+
+/// Add every value in an array by a scalar. If any value in the array is null
then the
+/// result is also null. The given array must be a `PrimitiveArray` of the
type same as
+/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
+///
+/// This detects overflow and returns an `Err` for that. For a
non-overflow-checking variant,
+/// use `add_scalar_dyn` instead.
+///
+/// As this kernel has branching costs and also prevents LLVM from
vectorising it correctly,
+/// it is usually much slower than the non-checking variant.
+pub fn add_scalar_checked_dyn<T>(array: &dyn Array, scalar: T::Native) ->
Result<ArrayRef>
+where
+ T: ArrowNumericType,
+ T::Native: ArrowNativeTypeOp,
{
- unary_dyn::<_, T>(array, |value| value + scalar)
+ try_unary_dyn::<_, T>(array, |value| {
+ value.add_checked(scalar).ok_or_else(|| {
+ ArrowError::CastError(format!("Overflow: adding {:?} to {:?}",
scalar, value))
+ })
+ })
+ .map(|a| Arc::new(a) as ArrayRef)
}
/// Perform `left - right` operation on two arrays. If either left or right
value is null
@@ -937,16 +966,40 @@ where
/// Subtract every value in an array by a scalar. If any value in the array is
null then the
/// result is also null. The given array must be a `PrimitiveArray` of the
type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap
around.
+/// For an overflow-checking variant, use `subtract_scalar_checked_dyn`
instead.
pub fn subtract_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) ->
Result<ArrayRef>
where
- T: datatypes::ArrowNumericType,
- T::Native: Add<Output = T::Native>
- + Sub<Output = T::Native>
- + Mul<Output = T::Native>
- + Div<Output = T::Native>
- + Zero,
+ T: ArrowNumericType,
+ T::Native: ArrowNativeTypeOp,
+{
+ unary_dyn::<_, T>(array, |value| value.sub_wrapping(scalar))
+}
+
+/// Subtract every value in an array by a scalar. If any value in the array is
null then the
+/// result is also null. The given array must be a `PrimitiveArray` of the
type same as
+/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
+///
+/// This detects overflow and returns an `Err` for that. For a
non-overflow-checking variant,
+/// use `subtract_scalar_dyn` instead.
+pub fn subtract_scalar_checked_dyn<T>(
+ array: &dyn Array,
+ scalar: T::Native,
+) -> Result<ArrayRef>
+where
+ T: ArrowNumericType,
+ T::Native: ArrowNativeTypeOp,
{
- unary_dyn::<_, T>(array, |value| value - scalar)
+ try_unary_dyn::<_, T>(array, |value| {
+ value.sub_checked(scalar).ok_or_else(|| {
+ ArrowError::CastError(format!(
+ "Overflow: subtracting {:?} from {:?}",
+ scalar, value
+ ))
+ })
+ })
+ .map(|a| Arc::new(a) as ArrayRef)
}
/// Perform `-` operation on an array. If value is null then the result is
also null.
@@ -1065,18 +1118,40 @@ where
/// Multiply every value in an array by a scalar. If any value in the array is
null then the
/// result is also null. The given array must be a `PrimitiveArray` of the
type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap
around.
+/// For an overflow-checking variant, use `multiply_scalar_checked_dyn`
instead.
pub fn multiply_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) ->
Result<ArrayRef>
where
T: ArrowNumericType,
- T::Native: Add<Output = T::Native>
- + Sub<Output = T::Native>
- + Mul<Output = T::Native>
- + Div<Output = T::Native>
- + Rem<Output = T::Native>
- + Zero
- + One,
+ T::Native: ArrowNativeTypeOp,
+{
+ unary_dyn::<_, T>(array, |value| value.mul_wrapping(scalar))
+}
+
+/// Multiply every value in an array by a scalar. If any value in the array is
null then the
+/// result is also null. The given array must be a `PrimitiveArray` of the
type same as
+/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
+///
+/// This detects overflow and returns an `Err` for that. For a
non-overflow-checking variant,
+/// use `multiply_scalar_dyn` instead.
+pub fn multiply_scalar_checked_dyn<T>(
+ array: &dyn Array,
+ scalar: T::Native,
+) -> Result<ArrayRef>
+where
+ T: ArrowNumericType,
+ T::Native: ArrowNativeTypeOp,
{
- unary_dyn::<_, T>(array, |value| value * scalar)
+ try_unary_dyn::<_, T>(array, |value| {
+ value.mul_checked(scalar).ok_or_else(|| {
+ ArrowError::CastError(format!(
+ "Overflow: multiplying {:?} by {:?}",
+ value, scalar
+ ))
+ })
+ })
+ .map(|a| Arc::new(a) as ArrayRef)
}
/// Perform `left % right` operation on two arrays. If either left or right
value is null
@@ -1223,15 +1298,48 @@ where
/// result is also null. If the scalar is zero then the result of this
operation will be
/// `Err(ArrowError::DivideByZero)`. The given array must be a
`PrimitiveArray` of the type
/// same as the scalar, or a `DictionaryArray` of the value type same as the
scalar.
+///
+/// This doesn't detect overflow. Once overflowing, the result will wrap
around.
+/// For an overflow-checking variant, use `divide_scalar_checked_dyn` instead.
pub fn divide_scalar_dyn<T>(array: &dyn Array, divisor: T::Native) ->
Result<ArrayRef>
where
T: ArrowNumericType,
- T::Native: Div<Output = T::Native> + Zero,
+ T::Native: ArrowNativeTypeOp + Zero,
+{
+ if divisor.is_zero() {
+ return Err(ArrowError::DivideByZero);
+ }
+ unary_dyn::<_, T>(array, |value| value.div_wrapping(divisor))
+}
+
+/// Divide every value in an array by a scalar. If any value in the array is
null then the
+/// result is also null. If the scalar is zero then the result of this
operation will be
+/// `Err(ArrowError::DivideByZero)`. The given array must be a
`PrimitiveArray` of the type
+/// same as the scalar, or a `DictionaryArray` of the value type same as the
scalar.
+///
+/// This detects overflow and returns an `Err` for that. For a
non-overflow-checking variant,
+/// use `divide_scalar_dyn` instead.
+pub fn divide_scalar_checked_dyn<T>(
+ array: &dyn Array,
+ divisor: T::Native,
+) -> Result<ArrayRef>
+where
+ T: ArrowNumericType,
+ T::Native: ArrowNativeTypeOp + Zero,
{
if divisor.is_zero() {
return Err(ArrowError::DivideByZero);
}
- unary_dyn::<_, T>(array, |value| value / divisor)
+
+ try_unary_dyn::<_, T>(array, |value| {
+ value.div_checked(divisor).ok_or_else(|| {
+ ArrowError::CastError(format!(
+ "Overflow: dividing {:?} by {:?}",
+ value, divisor
+ ))
+ })
+ })
+ .map(|a| Arc::new(a) as ArrayRef)
}
#[cfg(test)]
@@ -2222,6 +2330,55 @@ mod tests {
overflow.expect_err("overflow should be detected");
}
+ #[test]
+ fn test_primitive_add_scalar_dyn_wrapping_overflow() {
+ let a = Int32Array::from(vec![i32::MAX, i32::MIN]);
+
+ let wrapped = add_scalar_dyn::<Int32Type>(&a, 1).unwrap();
+ let expected =
+ Arc::new(Int32Array::from(vec![-2147483648, -2147483647])) as
ArrayRef;
+ assert_eq!(&expected, &wrapped);
+
+ let overflow = add_scalar_checked_dyn::<Int32Type>(&a, 1);
+ overflow.expect_err("overflow should be detected");
+ }
+
+ #[test]
+ fn test_primitive_subtract_scalar_dyn_wrapping_overflow() {
+ let a = Int32Array::from(vec![-2]);
+
+ let wrapped = subtract_scalar_dyn::<Int32Type>(&a, i32::MAX).unwrap();
+ let expected = Arc::new(Int32Array::from(vec![i32::MAX])) as ArrayRef;
+ assert_eq!(&expected, &wrapped);
+
+ let overflow = subtract_scalar_checked_dyn::<Int32Type>(&a, i32::MAX);
+ overflow.expect_err("overflow should be detected");
+ }
+
+ #[test]
+ fn test_primitive_mul_scalar_dyn_wrapping_overflow() {
+ let a = Int32Array::from(vec![10]);
+
+ let wrapped = multiply_scalar_dyn::<Int32Type>(&a, i32::MAX).unwrap();
+ let expected = Arc::new(Int32Array::from(vec![-10])) as ArrayRef;
+ assert_eq!(&expected, &wrapped);
+
+ let overflow = multiply_scalar_checked_dyn::<Int32Type>(&a, i32::MAX);
+ overflow.expect_err("overflow should be detected");
+ }
+
+ #[test]
+ fn test_primitive_div_scalar_dyn_wrapping_overflow() {
+ let a = Int32Array::from(vec![i32::MIN]);
+
+ let wrapped = divide_scalar_dyn::<Int32Type>(&a, -1).unwrap();
+ let expected = Arc::new(Int32Array::from(vec![-2147483648])) as
ArrayRef;
+ assert_eq!(&expected, &wrapped);
+
+ let overflow = divide_scalar_checked_dyn::<Int32Type>(&a, -1);
+ overflow.expect_err("overflow should be detected");
+ }
+
#[test]
fn test_primitive_div_opt_overflow_division_by_zero() {
let a = Int32Array::from(vec![i32::MIN]);
diff --git a/arrow/src/compute/kernels/arity.rs
b/arrow/src/compute/kernels/arity.rs
index fffa81af8..21c633116 100644
--- a/arrow/src/compute/kernels/arity.rs
+++ b/arrow/src/compute/kernels/arity.rs
@@ -123,7 +123,7 @@ where
Ok(unsafe { build_primitive_array(len, buffer.finish(), null_count,
null_buffer) })
}
-/// A helper function that applies an unary function to a dictionary array
with primitive value type.
+/// A helper function that applies an infallible unary function to a
dictionary array with primitive value type.
fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
where
K: ArrowNumericType,
@@ -138,7 +138,22 @@ where
Ok(Arc::new(new_dict))
}
-/// Applies an unary function to an array with primitive values.
+/// A helper function that applies a fallible unary function to a dictionary
array with primitive value type.
+fn try_unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) ->
Result<ArrayRef>
+where
+ K: ArrowNumericType,
+ T: ArrowPrimitiveType,
+ F: Fn(T::Native) -> Result<T::Native>,
+{
+ let dict_values = array.values().as_any().downcast_ref().unwrap();
+ let values = try_unary::<T, F, T>(dict_values, op)?.into_data();
+ let data = array.data().clone().into_builder().child_data(vec![values]);
+
+ let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked()
}.into();
+ Ok(Arc::new(new_dict))
+}
+
+/// Applies an infallible unary function to an array with primitive values.
pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
where
T: ArrowPrimitiveType,
@@ -162,6 +177,37 @@ where
}
}
+/// Applies a fallible unary function to an array with primitive values.
+pub fn try_unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
+where
+ T: ArrowPrimitiveType,
+ F: Fn(T::Native) -> Result<T::Native>,
+{
+ downcast_dictionary_array! {
+ array => if array.values().data_type() == &T::DATA_TYPE {
+ try_unary_dict::<_, F, T>(array, op)
+ } else {
+ Err(ArrowError::NotYetImplemented(format!(
+ "Cannot perform unary operation on dictionary array of type
{}",
+ array.data_type()
+ )))
+ },
+ t => {
+ if t == &T::DATA_TYPE {
+ Ok(Arc::new(try_unary::<T, F, T>(
+
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
+ op,
+ )?))
+ } else {
+ Err(ArrowError::NotYetImplemented(format!(
+ "Cannot perform unary operation on array of type {}",
+ t
+ )))
+ }
+ }
+ }
+}
+
/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in
`0..len`, collecting
/// the results in a [`PrimitiveArray`]. If any index is null in either `a` or
`b`, the
/// corresponding index in the result will also be null