This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new f572ec1be Update `try_binary` and `checked_ops`, and remove
`math_checked_op` (#2717)
f572ec1be is described below
commit f572ec1bef4a66a00b78f1d80a39992d63444ec2
Author: Remzi Yang <[email protected]>
AuthorDate: Fri Sep 16 18:47:20 2022 +0800
Update `try_binary` and `checked_ops`, and remove `math_checked_op` (#2717)
* update try_binary
delete math_checked_op
update the return type of checked ops
Signed-off-by: remzi <[email protected]>
* float div not panic on zero
Signed-off-by: remzi <[email protected]>
* fix nan test
Signed-off-by: remzi <[email protected]>
* add float divide by zero
Signed-off-by: remzi <[email protected]>
* add float tests
Signed-off-by: remzi <[email protected]>
* fix compile error
Signed-off-by: remzi <[email protected]>
Signed-off-by: remzi <[email protected]>
---
arrow/Cargo.toml | 2 +-
arrow/src/compute/kernels/arithmetic.rs | 220 ++++++++++++++------------------
arrow/src/compute/kernels/arity.rs | 14 +-
arrow/src/datatypes/native.rs | 66 +++++++---
4 files changed, 153 insertions(+), 149 deletions(-)
diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml
index e52940b4f..1580856df 100644
--- a/arrow/Cargo.toml
+++ b/arrow/Cargo.toml
@@ -51,7 +51,7 @@ serde_json = { version = "1.0", default-features = false,
features = ["std"], op
indexmap = { version = "1.9", default-features = false, features = ["std"] }
rand = { version = "0.8", default-features = false, features = ["std",
"std_rng"], optional = true }
num = { version = "0.4", default-features = false, features = ["std"] }
-half = { version = "2.0", default-features = false }
+half = { version = "2.0", default-features = false, features = ["num-traits"]}
hashbrown = { version = "0.12", default-features = false }
csv_crate = { version = "1.1", default-features = false, optional = true,
package = "csv" }
regex = { version = "1.5.6", default-features = false, features = ["std",
"unicode"] }
diff --git a/arrow/src/compute/kernels/arithmetic.rs
b/arrow/src/compute/kernels/arithmetic.rs
index 04fe2393e..7b91a261c 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -78,32 +78,6 @@ where
Ok(binary(left, right, op))
}
-/// This is similar to `math_op` as it performs given operation between two
input primitive arrays.
-/// But the given operation can return `None` if overflow is detected. For the
case, this function
-/// returns an `Err`.
-fn math_checked_op<LT, RT, F>(
- left: &PrimitiveArray<LT>,
- right: &PrimitiveArray<RT>,
- op: F,
-) -> Result<PrimitiveArray<LT>>
-where
- LT: ArrowNumericType,
- RT: ArrowNumericType,
- F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
-{
- if left.len() != right.len() {
- return Err(ArrowError::ComputeError(
- "Cannot perform math operation on arrays of different
length".to_string(),
- ));
- }
-
- try_binary(left, right, |a, b| {
- op(a, b).ok_or_else(|| {
- ArrowError::ComputeError(format!("Overflow happened on: {:?},
{:?}", a, b))
- })
- })
-}
-
/// Helper function for operations where a valid `0` on the right array should
/// result in an [ArrowError::DivideByZero], namely the division and modulo
operations
///
@@ -121,26 +95,9 @@ where
LT: ArrowNumericType,
RT: ArrowNumericType,
RT::Native: One + Zero,
- F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
+ F: Fn(LT::Native, RT::Native) -> Result<LT::Native>,
{
- if left.len() != right.len() {
- return Err(ArrowError::ComputeError(
- "Cannot perform math operation on arrays of different
length".to_string(),
- ));
- }
-
- try_binary(left, right, |l, r| {
- if r.is_zero() {
- Err(ArrowError::DivideByZero)
- } else {
- op(l, r).ok_or_else(|| {
- ArrowError::ComputeError(format!(
- "Overflow happened on: {:?}, {:?}",
- l, r
- ))
- })
- }
- })
+ try_binary(left, right, op)
}
/// Helper function for operations where a valid `0` on the right array should
@@ -161,16 +118,12 @@ fn math_checked_divide_op_on_iters<T, F>(
where
T: ArrowNumericType,
T::Native: One + Zero,
- F: Fn(T::Native, T::Native) -> T::Native,
+ F: Fn(T::Native, T::Native) -> Result<T::Native>,
{
let buffer = if null_bit_buffer.is_some() {
let values = left.zip(right).map(|(left, right)| {
if let (Some(l), Some(r)) = (left, right) {
- if r.is_zero() {
- Err(ArrowError::DivideByZero)
- } else {
- Ok(op(l, r))
- }
+ op(l, r)
} else {
Ok(T::default_value())
}
@@ -179,15 +132,10 @@ where
unsafe { Buffer::try_from_trusted_len_iter(values) }
} else {
// no value is null
- let values = left.map(|l| l.unwrap()).zip(right.map(|r|
r.unwrap())).map(
- |(left, right)| {
- if right.is_zero() {
- Err(ArrowError::DivideByZero)
- } else {
- Ok(op(left, right))
- }
- },
- );
+ let values = left
+ .map(|l| l.unwrap())
+ .zip(right.map(|r| r.unwrap()))
+ .map(|(left, right)| op(left, right));
// Safety: Iterator comes from a PrimitiveArray which reports its size
correctly
unsafe { Buffer::try_from_trusted_len_iter(values) }
}?;
@@ -654,7 +602,7 @@ where
K: ArrowNumericType,
T: ArrowNumericType,
T::Native: One + Zero,
- F: Fn(T::Native, T::Native) -> T::Native,
+ F: Fn(T::Native, T::Native) -> Result<T::Native>,
{
if left.len() != right.len() {
return Err(ArrowError::ComputeError(format!(
@@ -725,7 +673,7 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
- math_checked_op(left, right, |a, b| a.add_checked(b))
+ try_binary(left, right, |a, b| a.add_checked(b))
}
/// Perform `left + right` operation on two arrays. If either left or right
value is null
@@ -826,11 +774,7 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
- try_unary(array, |value| {
- value.add_checked(scalar).ok_or_else(|| {
- ArrowError::CastError(format!("Overflow: adding {:?} to {:?}",
scalar, value))
- })
- })
+ try_unary(array, |value| value.add_checked(scalar))
}
/// Add every value in an array by a scalar. If any value in the array is null
then the
@@ -863,12 +807,8 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
- try_unary_dyn::<_, T>(array, |value| {
- value.add_checked(scalar).ok_or_else(|| {
- ArrowError::CastError(format!("Overflow: adding {:?} to {:?}",
scalar, value))
- })
- })
- .map(|a| Arc::new(a) as ArrayRef)
+ try_unary_dyn::<_, T>(array, |value| value.add_checked(scalar))
+ .map(|a| Arc::new(a) as ArrayRef)
}
/// Perform `left - right` operation on two arrays. If either left or right
value is null
@@ -900,7 +840,7 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
- math_checked_op(left, right, |a, b| a.sub_checked(b))
+ try_binary(left, right, |a, b| a.sub_checked(b))
}
/// Perform `left - right` operation on two arrays. If either left or right
value is null
@@ -953,14 +893,7 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero,
{
- try_unary(array, |value| {
- value.sub_checked(scalar).ok_or_else(|| {
- ArrowError::CastError(format!(
- "Overflow: subtracting {:?} from {:?}",
- scalar, value
- ))
- })
- })
+ try_unary(array, |value| value.sub_checked(scalar))
}
/// Subtract every value in an array by a scalar. If any value in the array is
null then the
@@ -991,15 +924,8 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
- try_unary_dyn::<_, T>(array, |value| {
- value.sub_checked(scalar).ok_or_else(|| {
- ArrowError::CastError(format!(
- "Overflow: subtracting {:?} from {:?}",
- scalar, value
- ))
- })
- })
- .map(|a| Arc::new(a) as ArrayRef)
+ try_unary_dyn::<_, T>(array, |value| value.sub_checked(scalar))
+ .map(|a| Arc::new(a) as ArrayRef)
}
/// Perform `-` operation on an array. If value is null then the result is
also null.
@@ -1052,7 +978,7 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
- math_checked_op(left, right, |a, b| a.mul_checked(b))
+ try_binary(left, right, |a, b| a.mul_checked(b))
}
/// Perform `left * right` operation on two arrays. If either left or right
value is null
@@ -1105,14 +1031,7 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero + One,
{
- try_unary(array, |value| {
- value.mul_checked(scalar).ok_or_else(|| {
- ArrowError::CastError(format!(
- "Overflow: multiplying {:?} by {:?}",
- value, scalar,
- ))
- })
- })
+ try_unary(array, |value| value.mul_checked(scalar))
}
/// Multiply every value in an array by a scalar. If any value in the array is
null then the
@@ -1143,15 +1062,8 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
- try_unary_dyn::<_, T>(array, |value| {
- value.mul_checked(scalar).ok_or_else(|| {
- ArrowError::CastError(format!(
- "Overflow: multiplying {:?} by {:?}",
- value, scalar
- ))
- })
- })
- .map(|a| Arc::new(a) as ArrayRef)
+ try_unary_dyn::<_, T>(array, |value| value.mul_checked(scalar))
+ .map(|a| Arc::new(a) as ArrayRef)
}
/// Perform `left % right` operation on two arrays. If either left or right
value is null
@@ -1170,7 +1082,13 @@ where
a % b
});
#[cfg(not(feature = "simd"))]
- return math_checked_divide_op(left, right, |a, b| Some(a % b));
+ return try_binary(left, right, |a, b| {
+ if b.is_zero() {
+ Err(ArrowError::DivideByZero)
+ } else {
+ Ok(a % b)
+ }
+ });
}
/// Perform `left / right` operation on two arrays. If either left or right
value is null
@@ -1225,12 +1143,17 @@ where
pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {
match left.data_type() {
DataType::Dictionary(_, _) => {
- typed_dict_math_op!(left, right, |a, b| a / b,
math_divide_checked_op_dict)
+ typed_dict_math_op!(
+ left,
+ right,
+ |a, b| a.div_checked(b),
+ math_divide_checked_op_dict
+ )
}
_ => {
downcast_primitive_array!(
(left, right) => {
- math_checked_divide_op(left, right, |a, b| Some(a /
b)).map(|a| Arc::new(a) as ArrayRef)
+ math_checked_divide_op(left, right, |a, b|
a.div_checked(b)).map(|a| Arc::new(a) as ArrayRef)
}
_ => Err(ArrowError::CastError(format!(
"Unsupported data type {}, {}",
@@ -1331,15 +1254,8 @@ where
return Err(ArrowError::DivideByZero);
}
- try_unary_dyn::<_, T>(array, |value| {
- value.div_checked(divisor).ok_or_else(|| {
- ArrowError::CastError(format!(
- "Overflow: dividing {:?} by {:?}",
- value, divisor
- ))
- })
- })
- .map(|a| Arc::new(a) as ArrayRef)
+ try_unary_dyn::<_, T>(array, |value| value.div_checked(divisor))
+ .map(|a| Arc::new(a) as ArrayRef)
}
#[cfg(test)]
@@ -2134,23 +2050,41 @@ mod tests {
#[test]
#[should_panic(expected = "DivideByZero")]
- fn test_primitive_array_divide_by_zero_with_checked() {
+ fn test_int_array_divide_by_zero_with_checked() {
let a = Int32Array::from(vec![15]);
let b = Int32Array::from(vec![0]);
divide_checked(&a, &b).unwrap();
}
+ #[test]
+ #[should_panic(expected = "DivideByZero")]
+ fn test_f32_array_divide_by_zero_with_checked() {
+ let a = Float32Array::from(vec![15.0]);
+ let b = Float32Array::from(vec![0.0]);
+ divide_checked(&a, &b).unwrap();
+ }
+
#[test]
#[should_panic(expected = "attempt to divide by zero")]
- fn test_primitive_array_divide_by_zero() {
+ fn test_int_array_divide_by_zero() {
let a = Int32Array::from(vec![15]);
let b = Int32Array::from(vec![0]);
divide(&a, &b).unwrap();
}
+ #[test]
+ fn test_f32_array_divide_by_zero() {
+ let a = Float32Array::from(vec![1.5, 0.0, -1.5]);
+ let b = Float32Array::from(vec![0.0, 0.0, 0.0]);
+ let result = divide(&a, &b).unwrap();
+ assert_eq!(result.value(0), f32::INFINITY);
+ assert!(result.value(1).is_nan());
+ assert_eq!(result.value(2), f32::NEG_INFINITY);
+ }
+
#[test]
#[should_panic(expected = "DivideByZero")]
- fn test_primitive_array_divide_dyn_by_zero() {
+ fn test_int_array_divide_dyn_by_zero() {
let a = Int32Array::from(vec![15]);
let b = Int32Array::from(vec![0]);
divide_dyn(&a, &b).unwrap();
@@ -2158,7 +2092,15 @@ mod tests {
#[test]
#[should_panic(expected = "DivideByZero")]
- fn test_primitive_array_divide_dyn_by_zero_dict() {
+ fn test_f32_array_divide_dyn_by_zero() {
+ let a = Float32Array::from(vec![1.5]);
+ let b = Float32Array::from(vec![0.0]);
+ divide_dyn(&a, &b).unwrap();
+ }
+
+ #[test]
+ #[should_panic(expected = "DivideByZero")]
+ fn test_int_array_divide_dyn_by_zero_dict() {
let mut builder =
PrimitiveDictionaryBuilder::<Int8Type,
Int32Type>::with_capacity(1, 1);
builder.append(15).unwrap();
@@ -2174,14 +2116,38 @@ mod tests {
#[test]
#[should_panic(expected = "DivideByZero")]
- fn test_primitive_array_modulus_by_zero() {
+ fn test_f32_dict_array_divide_dyn_by_zero() {
+ let mut builder =
+ PrimitiveDictionaryBuilder::<Int8Type,
Float32Type>::with_capacity(1, 1);
+ builder.append(1.5).unwrap();
+ let a = builder.finish();
+
+ let mut builder =
+ PrimitiveDictionaryBuilder::<Int8Type,
Float32Type>::with_capacity(1, 1);
+ builder.append(0.0).unwrap();
+ let b = builder.finish();
+
+ divide_dyn(&a, &b).unwrap();
+ }
+
+ #[test]
+ #[should_panic(expected = "DivideByZero")]
+ fn test_i32_array_modulus_by_zero() {
let a = Int32Array::from(vec![15]);
let b = Int32Array::from(vec![0]);
modulus(&a, &b).unwrap();
}
#[test]
- fn test_primitive_array_divide_f64() {
+ #[should_panic(expected = "DivideByZero")]
+ fn test_f32_array_modulus_by_zero() {
+ let a = Float32Array::from(vec![1.5]);
+ let b = Float32Array::from(vec![0.0]);
+ modulus(&a, &b).unwrap();
+ }
+
+ #[test]
+ fn test_f64_array_divide() {
let a = Float64Array::from(vec![15.0, 15.0, 8.0]);
let b = Float64Array::from(vec![5.0, 6.0, 8.0]);
let c = divide(&a, &b).unwrap();
diff --git a/arrow/src/compute/kernels/arity.rs
b/arrow/src/compute/kernels/arity.rs
index 21c633116..5060234c7 100644
--- a/arrow/src/compute/kernels/arity.rs
+++ b/arrow/src/compute/kernels/arity.rs
@@ -261,9 +261,10 @@ where
///
/// Like [`try_unary`] the function is only evaluated for non-null indices
///
-/// # Panic
+/// # Errors
///
-/// Panics if the arrays have different lengths
+/// Returns an error if the arrays have different lengths or
+/// if the operation itself returns an error
pub fn try_binary<A, B, F, O>(
a: &PrimitiveArray<A>,
b: &PrimitiveArray<B>,
@@ -275,13 +276,16 @@ where
O: ArrowPrimitiveType,
F: Fn(A::Native, B::Native) -> Result<O::Native>,
{
- assert_eq!(a.len(), b.len());
- let len = a.len();
-
+ if a.len() != b.len() {
+ return Err(ArrowError::ComputeError(
+ "Cannot perform a binary operation on arrays of different
length".to_string(),
+ ));
+ }
if a.is_empty() {
return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE)));
}
+ let len = a.len();
let null_buffer = combine_option_bitmap(&[a.data(), b.data()],
len).unwrap();
let null_count = null_buffer
.as_ref()
diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs
index de35c4804..dec0cc4b5 100644
--- a/arrow/src/datatypes/native.rs
+++ b/arrow/src/datatypes/native.rs
@@ -16,8 +16,10 @@
// under the License.
use super::DataType;
+use crate::error::{ArrowError, Result};
pub use arrow_buffer::{ArrowNativeType, ToByteSlice};
use half::f16;
+use num::Zero;
/// Trait bridging the dynamic-typed nature of Arrow (via [`DataType`]) with
the
/// static-typed nature of rust types ([`ArrowNativeType`]) for all types that
implement [`ArrowNativeType`].
@@ -43,6 +45,8 @@ pub trait ArrowPrimitiveType: 'static {
pub(crate) mod native_op {
use super::ArrowNativeType;
+ use crate::error::{ArrowError, Result};
+ use num::Zero;
use std::ops::{Add, Div, Mul, Sub};
/// Trait for ArrowNativeType to provide overflow-checking and
non-overflow-checking
@@ -61,33 +65,38 @@ pub(crate) mod native_op {
+ Sub<Output = Self>
+ Mul<Output = Self>
+ Div<Output = Self>
+ + Zero
{
- fn add_checked(self, rhs: Self) -> Option<Self> {
- Some(self + rhs)
+ fn add_checked(self, rhs: Self) -> Result<Self> {
+ Ok(self + rhs)
}
fn add_wrapping(self, rhs: Self) -> Self {
self + rhs
}
- fn sub_checked(self, rhs: Self) -> Option<Self> {
- Some(self - rhs)
+ fn sub_checked(self, rhs: Self) -> Result<Self> {
+ Ok(self - rhs)
}
fn sub_wrapping(self, rhs: Self) -> Self {
self - rhs
}
- fn mul_checked(self, rhs: Self) -> Option<Self> {
- Some(self * rhs)
+ fn mul_checked(self, rhs: Self) -> Result<Self> {
+ Ok(self * rhs)
}
fn mul_wrapping(self, rhs: Self) -> Self {
self * rhs
}
- fn div_checked(self, rhs: Self) -> Option<Self> {
- Some(self / rhs)
+ fn div_checked(self, rhs: Self) -> Result<Self> {
+ if rhs.is_zero() {
+ Err(ArrowError::DivideByZero)
+ } else {
+ Ok(self / rhs)
+ }
}
fn div_wrapping(self, rhs: Self) -> Self {
@@ -99,32 +108,56 @@ pub(crate) mod native_op {
macro_rules! native_type_op {
($t:tt) => {
impl native_op::ArrowNativeTypeOp for $t {
- fn add_checked(self, rhs: Self) -> Option<Self> {
- self.checked_add(rhs)
+ fn add_checked(self, rhs: Self) -> Result<Self> {
+ self.checked_add(rhs).ok_or_else(|| {
+ ArrowError::ComputeError(format!(
+ "Overflow happened on: {:?} + {:?}",
+ self, rhs
+ ))
+ })
}
fn add_wrapping(self, rhs: Self) -> Self {
self.wrapping_add(rhs)
}
- fn sub_checked(self, rhs: Self) -> Option<Self> {
- self.checked_sub(rhs)
+ fn sub_checked(self, rhs: Self) -> Result<Self> {
+ self.checked_sub(rhs).ok_or_else(|| {
+ ArrowError::ComputeError(format!(
+ "Overflow happened on: {:?} - {:?}",
+ self, rhs
+ ))
+ })
}
fn sub_wrapping(self, rhs: Self) -> Self {
self.wrapping_sub(rhs)
}
- fn mul_checked(self, rhs: Self) -> Option<Self> {
- self.checked_mul(rhs)
+ fn mul_checked(self, rhs: Self) -> Result<Self> {
+ self.checked_mul(rhs).ok_or_else(|| {
+ ArrowError::ComputeError(format!(
+ "Overflow happened on: {:?} * {:?}",
+ self, rhs
+ ))
+ })
}
fn mul_wrapping(self, rhs: Self) -> Self {
self.wrapping_mul(rhs)
}
- fn div_checked(self, rhs: Self) -> Option<Self> {
- self.checked_div(rhs)
+ fn div_checked(self, rhs: Self) -> Result<Self> {
+ if rhs.is_zero() {
+ Err(ArrowError::DivideByZero)
+ } else {
+ self.checked_div(rhs).ok_or_else(|| {
+ ArrowError::ComputeError(format!(
+ "Overflow happened on: {:?} / {:?}",
+ self, rhs
+ ))
+ })
+ }
}
fn div_wrapping(self, rhs: Self) -> Self {
@@ -138,6 +171,7 @@ native_type_op!(i8);
native_type_op!(i16);
native_type_op!(i32);
native_type_op!(i64);
+native_type_op!(i128);
native_type_op!(u8);
native_type_op!(u16);
native_type_op!(u32);