This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new f572ec1be Update `try_binary` and `checked_ops`, and remove `math_checked_op` (#2717)
f572ec1be is described below

commit f572ec1bef4a66a00b78f1d80a39992d63444ec2
Author: Remzi Yang <[email protected]>
AuthorDate: Fri Sep 16 18:47:20 2022 +0800

    Update `try_binary` and `checked_ops`, and remove `math_checked_op` (#2717)
    
    * update try_binary
    delete math_checked_op
    update the return type of checked ops
    
    Signed-off-by: remzi <[email protected]>
    
    * float div not panic on zero
    
    Signed-off-by: remzi <[email protected]>
    
    * fix nan test
    
    Signed-off-by: remzi <[email protected]>
    
    * add float divide by zero
    
    Signed-off-by: remzi <[email protected]>
    
    * add float tests
    
    Signed-off-by: remzi <[email protected]>
    
    * fix compile error
    
    Signed-off-by: remzi <[email protected]>
    
    Signed-off-by: remzi <[email protected]>
---
 arrow/Cargo.toml                        |   2 +-
 arrow/src/compute/kernels/arithmetic.rs | 220 ++++++++++++++------------------
 arrow/src/compute/kernels/arity.rs      |  14 +-
 arrow/src/datatypes/native.rs           |  66 +++++++---
 4 files changed, 153 insertions(+), 149 deletions(-)

diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml
index e52940b4f..1580856df 100644
--- a/arrow/Cargo.toml
+++ b/arrow/Cargo.toml
@@ -51,7 +51,7 @@ serde_json = { version = "1.0", default-features = false, 
features = ["std"], op
 indexmap = { version = "1.9", default-features = false, features = ["std"] }
 rand = { version = "0.8", default-features = false, features = ["std", 
"std_rng"], optional = true }
 num = { version = "0.4", default-features = false, features = ["std"] }
-half = { version = "2.0", default-features = false }
+half = { version = "2.0", default-features = false, features = ["num-traits"]}
 hashbrown = { version = "0.12", default-features = false }
 csv_crate = { version = "1.1", default-features = false, optional = true, 
package = "csv" }
 regex = { version = "1.5.6", default-features = false, features = ["std", 
"unicode"] }
diff --git a/arrow/src/compute/kernels/arithmetic.rs 
b/arrow/src/compute/kernels/arithmetic.rs
index 04fe2393e..7b91a261c 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -78,32 +78,6 @@ where
     Ok(binary(left, right, op))
 }
 
-/// This is similar to `math_op` as it performs given operation between two 
input primitive arrays.
-/// But the given operation can return `None` if overflow is detected. For the 
case, this function
-/// returns an `Err`.
-fn math_checked_op<LT, RT, F>(
-    left: &PrimitiveArray<LT>,
-    right: &PrimitiveArray<RT>,
-    op: F,
-) -> Result<PrimitiveArray<LT>>
-where
-    LT: ArrowNumericType,
-    RT: ArrowNumericType,
-    F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
-{
-    if left.len() != right.len() {
-        return Err(ArrowError::ComputeError(
-            "Cannot perform math operation on arrays of different 
length".to_string(),
-        ));
-    }
-
-    try_binary(left, right, |a, b| {
-        op(a, b).ok_or_else(|| {
-            ArrowError::ComputeError(format!("Overflow happened on: {:?}, 
{:?}", a, b))
-        })
-    })
-}
-
 /// Helper function for operations where a valid `0` on the right array should
 /// result in an [ArrowError::DivideByZero], namely the division and modulo 
operations
 ///
@@ -121,26 +95,9 @@ where
     LT: ArrowNumericType,
     RT: ArrowNumericType,
     RT::Native: One + Zero,
-    F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
+    F: Fn(LT::Native, RT::Native) -> Result<LT::Native>,
 {
-    if left.len() != right.len() {
-        return Err(ArrowError::ComputeError(
-            "Cannot perform math operation on arrays of different 
length".to_string(),
-        ));
-    }
-
-    try_binary(left, right, |l, r| {
-        if r.is_zero() {
-            Err(ArrowError::DivideByZero)
-        } else {
-            op(l, r).ok_or_else(|| {
-                ArrowError::ComputeError(format!(
-                    "Overflow happened on: {:?}, {:?}",
-                    l, r
-                ))
-            })
-        }
-    })
+    try_binary(left, right, op)
 }
 
 /// Helper function for operations where a valid `0` on the right array should
@@ -161,16 +118,12 @@ fn math_checked_divide_op_on_iters<T, F>(
 where
     T: ArrowNumericType,
     T::Native: One + Zero,
-    F: Fn(T::Native, T::Native) -> T::Native,
+    F: Fn(T::Native, T::Native) -> Result<T::Native>,
 {
     let buffer = if null_bit_buffer.is_some() {
         let values = left.zip(right).map(|(left, right)| {
             if let (Some(l), Some(r)) = (left, right) {
-                if r.is_zero() {
-                    Err(ArrowError::DivideByZero)
-                } else {
-                    Ok(op(l, r))
-                }
+                op(l, r)
             } else {
                 Ok(T::default_value())
             }
@@ -179,15 +132,10 @@ where
         unsafe { Buffer::try_from_trusted_len_iter(values) }
     } else {
         // no value is null
-        let values = left.map(|l| l.unwrap()).zip(right.map(|r| 
r.unwrap())).map(
-            |(left, right)| {
-                if right.is_zero() {
-                    Err(ArrowError::DivideByZero)
-                } else {
-                    Ok(op(left, right))
-                }
-            },
-        );
+        let values = left
+            .map(|l| l.unwrap())
+            .zip(right.map(|r| r.unwrap()))
+            .map(|(left, right)| op(left, right));
         // Safety: Iterator comes from a PrimitiveArray which reports its size 
correctly
         unsafe { Buffer::try_from_trusted_len_iter(values) }
     }?;
@@ -654,7 +602,7 @@ where
     K: ArrowNumericType,
     T: ArrowNumericType,
     T::Native: One + Zero,
-    F: Fn(T::Native, T::Native) -> T::Native,
+    F: Fn(T::Native, T::Native) -> Result<T::Native>,
 {
     if left.len() != right.len() {
         return Err(ArrowError::ComputeError(format!(
@@ -725,7 +673,7 @@ where
     T: ArrowNumericType,
     T::Native: ArrowNativeTypeOp,
 {
-    math_checked_op(left, right, |a, b| a.add_checked(b))
+    try_binary(left, right, |a, b| a.add_checked(b))
 }
 
 /// Perform `left + right` operation on two arrays. If either left or right 
value is null
@@ -826,11 +774,7 @@ where
     T: ArrowNumericType,
     T::Native: ArrowNativeTypeOp,
 {
-    try_unary(array, |value| {
-        value.add_checked(scalar).ok_or_else(|| {
-            ArrowError::CastError(format!("Overflow: adding {:?} to {:?}", 
scalar, value))
-        })
-    })
+    try_unary(array, |value| value.add_checked(scalar))
 }
 
 /// Add every value in an array by a scalar. If any value in the array is null 
then the
@@ -863,12 +807,8 @@ where
     T: ArrowNumericType,
     T::Native: ArrowNativeTypeOp,
 {
-    try_unary_dyn::<_, T>(array, |value| {
-        value.add_checked(scalar).ok_or_else(|| {
-            ArrowError::CastError(format!("Overflow: adding {:?} to {:?}", 
scalar, value))
-        })
-    })
-    .map(|a| Arc::new(a) as ArrayRef)
+    try_unary_dyn::<_, T>(array, |value| value.add_checked(scalar))
+        .map(|a| Arc::new(a) as ArrayRef)
 }
 
 /// Perform `left - right` operation on two arrays. If either left or right 
value is null
@@ -900,7 +840,7 @@ where
     T: ArrowNumericType,
     T::Native: ArrowNativeTypeOp,
 {
-    math_checked_op(left, right, |a, b| a.sub_checked(b))
+    try_binary(left, right, |a, b| a.sub_checked(b))
 }
 
 /// Perform `left - right` operation on two arrays. If either left or right 
value is null
@@ -953,14 +893,7 @@ where
     T: ArrowNumericType,
     T::Native: ArrowNativeTypeOp + Zero,
 {
-    try_unary(array, |value| {
-        value.sub_checked(scalar).ok_or_else(|| {
-            ArrowError::CastError(format!(
-                "Overflow: subtracting {:?} from {:?}",
-                scalar, value
-            ))
-        })
-    })
+    try_unary(array, |value| value.sub_checked(scalar))
 }
 
 /// Subtract every value in an array by a scalar. If any value in the array is 
null then the
@@ -991,15 +924,8 @@ where
     T: ArrowNumericType,
     T::Native: ArrowNativeTypeOp,
 {
-    try_unary_dyn::<_, T>(array, |value| {
-        value.sub_checked(scalar).ok_or_else(|| {
-            ArrowError::CastError(format!(
-                "Overflow: subtracting {:?} from {:?}",
-                scalar, value
-            ))
-        })
-    })
-    .map(|a| Arc::new(a) as ArrayRef)
+    try_unary_dyn::<_, T>(array, |value| value.sub_checked(scalar))
+        .map(|a| Arc::new(a) as ArrayRef)
 }
 
 /// Perform `-` operation on an array. If value is null then the result is 
also null.
@@ -1052,7 +978,7 @@ where
     T: ArrowNumericType,
     T::Native: ArrowNativeTypeOp,
 {
-    math_checked_op(left, right, |a, b| a.mul_checked(b))
+    try_binary(left, right, |a, b| a.mul_checked(b))
 }
 
 /// Perform `left * right` operation on two arrays. If either left or right 
value is null
@@ -1105,14 +1031,7 @@ where
     T: ArrowNumericType,
     T::Native: ArrowNativeTypeOp + Zero + One,
 {
-    try_unary(array, |value| {
-        value.mul_checked(scalar).ok_or_else(|| {
-            ArrowError::CastError(format!(
-                "Overflow: multiplying {:?} by {:?}",
-                value, scalar,
-            ))
-        })
-    })
+    try_unary(array, |value| value.mul_checked(scalar))
 }
 
 /// Multiply every value in an array by a scalar. If any value in the array is 
null then the
@@ -1143,15 +1062,8 @@ where
     T: ArrowNumericType,
     T::Native: ArrowNativeTypeOp,
 {
-    try_unary_dyn::<_, T>(array, |value| {
-        value.mul_checked(scalar).ok_or_else(|| {
-            ArrowError::CastError(format!(
-                "Overflow: multiplying {:?} by {:?}",
-                value, scalar
-            ))
-        })
-    })
-    .map(|a| Arc::new(a) as ArrayRef)
+    try_unary_dyn::<_, T>(array, |value| value.mul_checked(scalar))
+        .map(|a| Arc::new(a) as ArrayRef)
 }
 
 /// Perform `left % right` operation on two arrays. If either left or right 
value is null
@@ -1170,7 +1082,13 @@ where
         a % b
     });
     #[cfg(not(feature = "simd"))]
-    return math_checked_divide_op(left, right, |a, b| Some(a % b));
+    return try_binary(left, right, |a, b| {
+        if b.is_zero() {
+            Err(ArrowError::DivideByZero)
+        } else {
+            Ok(a % b)
+        }
+    });
 }
 
 /// Perform `left / right` operation on two arrays. If either left or right 
value is null
@@ -1225,12 +1143,17 @@ where
 pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {
     match left.data_type() {
         DataType::Dictionary(_, _) => {
-            typed_dict_math_op!(left, right, |a, b| a / b, 
math_divide_checked_op_dict)
+            typed_dict_math_op!(
+                left,
+                right,
+                |a, b| a.div_checked(b),
+                math_divide_checked_op_dict
+            )
         }
         _ => {
             downcast_primitive_array!(
                 (left, right) => {
-                    math_checked_divide_op(left, right, |a, b| Some(a / 
b)).map(|a| Arc::new(a) as ArrayRef)
+                    math_checked_divide_op(left, right, |a, b| 
a.div_checked(b)).map(|a| Arc::new(a) as ArrayRef)
                 }
                 _ => Err(ArrowError::CastError(format!(
                     "Unsupported data type {}, {}",
@@ -1331,15 +1254,8 @@ where
         return Err(ArrowError::DivideByZero);
     }
 
-    try_unary_dyn::<_, T>(array, |value| {
-        value.div_checked(divisor).ok_or_else(|| {
-            ArrowError::CastError(format!(
-                "Overflow: dividing {:?} by {:?}",
-                value, divisor
-            ))
-        })
-    })
-    .map(|a| Arc::new(a) as ArrayRef)
+    try_unary_dyn::<_, T>(array, |value| value.div_checked(divisor))
+        .map(|a| Arc::new(a) as ArrayRef)
 }
 
 #[cfg(test)]
@@ -2134,23 +2050,41 @@ mod tests {
 
     #[test]
     #[should_panic(expected = "DivideByZero")]
-    fn test_primitive_array_divide_by_zero_with_checked() {
+    fn test_int_array_divide_by_zero_with_checked() {
         let a = Int32Array::from(vec![15]);
         let b = Int32Array::from(vec![0]);
         divide_checked(&a, &b).unwrap();
     }
 
+    #[test]
+    #[should_panic(expected = "DivideByZero")]
+    fn test_f32_array_divide_by_zero_with_checked() {
+        let a = Float32Array::from(vec![15.0]);
+        let b = Float32Array::from(vec![0.0]);
+        divide_checked(&a, &b).unwrap();
+    }
+
     #[test]
     #[should_panic(expected = "attempt to divide by zero")]
-    fn test_primitive_array_divide_by_zero() {
+    fn test_int_array_divide_by_zero() {
         let a = Int32Array::from(vec![15]);
         let b = Int32Array::from(vec![0]);
         divide(&a, &b).unwrap();
     }
 
+    #[test]
+    fn test_f32_array_divide_by_zero() {
+        let a = Float32Array::from(vec![1.5, 0.0, -1.5]);
+        let b = Float32Array::from(vec![0.0, 0.0, 0.0]);
+        let result = divide(&a, &b).unwrap();
+        assert_eq!(result.value(0), f32::INFINITY);
+        assert!(result.value(1).is_nan());
+        assert_eq!(result.value(2), f32::NEG_INFINITY);
+    }
+
     #[test]
     #[should_panic(expected = "DivideByZero")]
-    fn test_primitive_array_divide_dyn_by_zero() {
+    fn test_int_array_divide_dyn_by_zero() {
         let a = Int32Array::from(vec![15]);
         let b = Int32Array::from(vec![0]);
         divide_dyn(&a, &b).unwrap();
@@ -2158,7 +2092,15 @@ mod tests {
 
     #[test]
     #[should_panic(expected = "DivideByZero")]
-    fn test_primitive_array_divide_dyn_by_zero_dict() {
+    fn test_f32_array_divide_dyn_by_zero() {
+        let a = Float32Array::from(vec![1.5]);
+        let b = Float32Array::from(vec![0.0]);
+        divide_dyn(&a, &b).unwrap();
+    }
+
+    #[test]
+    #[should_panic(expected = "DivideByZero")]
+    fn test_int_array_divide_dyn_by_zero_dict() {
         let mut builder =
             PrimitiveDictionaryBuilder::<Int8Type, 
Int32Type>::with_capacity(1, 1);
         builder.append(15).unwrap();
@@ -2174,14 +2116,38 @@ mod tests {
 
     #[test]
     #[should_panic(expected = "DivideByZero")]
-    fn test_primitive_array_modulus_by_zero() {
+    fn test_f32_dict_array_divide_dyn_by_zero() {
+        let mut builder =
+            PrimitiveDictionaryBuilder::<Int8Type, 
Float32Type>::with_capacity(1, 1);
+        builder.append(1.5).unwrap();
+        let a = builder.finish();
+
+        let mut builder =
+            PrimitiveDictionaryBuilder::<Int8Type, 
Float32Type>::with_capacity(1, 1);
+        builder.append(0.0).unwrap();
+        let b = builder.finish();
+
+        divide_dyn(&a, &b).unwrap();
+    }
+
+    #[test]
+    #[should_panic(expected = "DivideByZero")]
+    fn test_i32_array_modulus_by_zero() {
         let a = Int32Array::from(vec![15]);
         let b = Int32Array::from(vec![0]);
         modulus(&a, &b).unwrap();
     }
 
     #[test]
-    fn test_primitive_array_divide_f64() {
+    #[should_panic(expected = "DivideByZero")]
+    fn test_f32_array_modulus_by_zero() {
+        let a = Float32Array::from(vec![1.5]);
+        let b = Float32Array::from(vec![0.0]);
+        modulus(&a, &b).unwrap();
+    }
+
+    #[test]
+    fn test_f64_array_divide() {
         let a = Float64Array::from(vec![15.0, 15.0, 8.0]);
         let b = Float64Array::from(vec![5.0, 6.0, 8.0]);
         let c = divide(&a, &b).unwrap();
diff --git a/arrow/src/compute/kernels/arity.rs 
b/arrow/src/compute/kernels/arity.rs
index 21c633116..5060234c7 100644
--- a/arrow/src/compute/kernels/arity.rs
+++ b/arrow/src/compute/kernels/arity.rs
@@ -261,9 +261,10 @@ where
 ///
 /// Like [`try_unary`] the function is only evaluated for non-null indices
 ///
-/// # Panic
+/// # Errors
 ///
-/// Panics if the arrays have different lengths
+/// Returns an error if the arrays have different lengths, or if
+/// the operation returns an error for any non-null pair of values
 pub fn try_binary<A, B, F, O>(
     a: &PrimitiveArray<A>,
     b: &PrimitiveArray<B>,
@@ -275,13 +276,16 @@ where
     O: ArrowPrimitiveType,
     F: Fn(A::Native, B::Native) -> Result<O::Native>,
 {
-    assert_eq!(a.len(), b.len());
-    let len = a.len();
-
+    if a.len() != b.len() {
+        return Err(ArrowError::ComputeError(
+            "Cannot perform a binary operation on arrays of different 
length".to_string(),
+        ));
+    }
     if a.is_empty() {
         return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE)));
     }
 
+    let len = a.len();
     let null_buffer = combine_option_bitmap(&[a.data(), b.data()], 
len).unwrap();
     let null_count = null_buffer
         .as_ref()
diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs
index de35c4804..dec0cc4b5 100644
--- a/arrow/src/datatypes/native.rs
+++ b/arrow/src/datatypes/native.rs
@@ -16,8 +16,10 @@
 // under the License.
 
 use super::DataType;
+use crate::error::{ArrowError, Result};
 pub use arrow_buffer::{ArrowNativeType, ToByteSlice};
 use half::f16;
+use num::Zero;
 
 /// Trait bridging the dynamic-typed nature of Arrow (via [`DataType`]) with 
the
 /// static-typed nature of rust types ([`ArrowNativeType`]) for all types that 
implement [`ArrowNativeType`].
@@ -43,6 +45,8 @@ pub trait ArrowPrimitiveType: 'static {
 
 pub(crate) mod native_op {
     use super::ArrowNativeType;
+    use crate::error::{ArrowError, Result};
+    use num::Zero;
     use std::ops::{Add, Div, Mul, Sub};
 
     /// Trait for ArrowNativeType to provide overflow-checking and 
non-overflow-checking
@@ -61,33 +65,38 @@ pub(crate) mod native_op {
         + Sub<Output = Self>
         + Mul<Output = Self>
         + Div<Output = Self>
+        + Zero
     {
-        fn add_checked(self, rhs: Self) -> Option<Self> {
-            Some(self + rhs)
+        fn add_checked(self, rhs: Self) -> Result<Self> {
+            Ok(self + rhs)
         }
 
         fn add_wrapping(self, rhs: Self) -> Self {
             self + rhs
         }
 
-        fn sub_checked(self, rhs: Self) -> Option<Self> {
-            Some(self - rhs)
+        fn sub_checked(self, rhs: Self) -> Result<Self> {
+            Ok(self - rhs)
         }
 
         fn sub_wrapping(self, rhs: Self) -> Self {
             self - rhs
         }
 
-        fn mul_checked(self, rhs: Self) -> Option<Self> {
-            Some(self * rhs)
+        fn mul_checked(self, rhs: Self) -> Result<Self> {
+            Ok(self * rhs)
         }
 
         fn mul_wrapping(self, rhs: Self) -> Self {
             self * rhs
         }
 
-        fn div_checked(self, rhs: Self) -> Option<Self> {
-            Some(self / rhs)
+        fn div_checked(self, rhs: Self) -> Result<Self> {
+            if rhs.is_zero() {
+                Err(ArrowError::DivideByZero)
+            } else {
+                Ok(self / rhs)
+            }
         }
 
         fn div_wrapping(self, rhs: Self) -> Self {
@@ -99,32 +108,56 @@ pub(crate) mod native_op {
 macro_rules! native_type_op {
     ($t:tt) => {
         impl native_op::ArrowNativeTypeOp for $t {
-            fn add_checked(self, rhs: Self) -> Option<Self> {
-                self.checked_add(rhs)
+            fn add_checked(self, rhs: Self) -> Result<Self> {
+                self.checked_add(rhs).ok_or_else(|| {
+                    ArrowError::ComputeError(format!(
+                        "Overflow happened on: {:?} + {:?}",
+                        self, rhs
+                    ))
+                })
             }
 
             fn add_wrapping(self, rhs: Self) -> Self {
                 self.wrapping_add(rhs)
             }
 
-            fn sub_checked(self, rhs: Self) -> Option<Self> {
-                self.checked_sub(rhs)
+            fn sub_checked(self, rhs: Self) -> Result<Self> {
+                self.checked_sub(rhs).ok_or_else(|| {
+                    ArrowError::ComputeError(format!(
+                        "Overflow happened on: {:?} - {:?}",
+                        self, rhs
+                    ))
+                })
             }
 
             fn sub_wrapping(self, rhs: Self) -> Self {
                 self.wrapping_sub(rhs)
             }
 
-            fn mul_checked(self, rhs: Self) -> Option<Self> {
-                self.checked_mul(rhs)
+            fn mul_checked(self, rhs: Self) -> Result<Self> {
+                self.checked_mul(rhs).ok_or_else(|| {
+                    ArrowError::ComputeError(format!(
+                        "Overflow happened on: {:?} * {:?}",
+                        self, rhs
+                    ))
+                })
             }
 
             fn mul_wrapping(self, rhs: Self) -> Self {
                 self.wrapping_mul(rhs)
             }
 
-            fn div_checked(self, rhs: Self) -> Option<Self> {
-                self.checked_div(rhs)
+            fn div_checked(self, rhs: Self) -> Result<Self> {
+                if rhs.is_zero() {
+                    Err(ArrowError::DivideByZero)
+                } else {
+                    self.checked_div(rhs).ok_or_else(|| {
+                        ArrowError::ComputeError(format!(
+                            "Overflow happened on: {:?} / {:?}",
+                            self, rhs
+                        ))
+                    })
+                }
             }
 
             fn div_wrapping(self, rhs: Self) -> Self {
@@ -138,6 +171,7 @@ native_type_op!(i8);
 native_type_op!(i16);
 native_type_op!(i32);
 native_type_op!(i64);
+native_type_op!(i128);
 native_type_op!(u8);
 native_type_op!(u16);
 native_type_op!(u32);

Reply via email to