This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 491b0239a Fix unary_dyn for decimal scalar arithmetic computation (#3345)
491b0239a is described below

commit 491b0239a81bb3e7e2829d69c5a59799a0d4f6e6
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sun Dec 18 14:15:31 2022 -0800

    Fix unary_dyn for decimal scalar arithmetic computation (#3345)
    
    * Fix unary for decimal arithmetic computation
    
    * Use discriminant
---
 arrow/src/compute/kernels/arithmetic.rs | 20 +++++++++++++++++++-
 arrow/src/compute/kernels/arity.rs      | 17 +++++++++++------
 2 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs
index 23cefe48e..913a2cad6 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -1633,7 +1633,7 @@ mod tests {
     use super::*;
     use crate::array::Int32Array;
     use crate::compute::{binary_mut, try_binary_mut, try_unary_mut, unary_mut};
-    use crate::datatypes::{Date64Type, Int32Type, Int8Type};
+    use crate::datatypes::{Date64Type, Decimal128Type, Int32Type, Int8Type};
     use arrow_buffer::i256;
     use chrono::NaiveDate;
     use half::f16;
@@ -3226,4 +3226,22 @@ mod tests {
         ])) as ArrayRef;
         assert_eq!(&result, &expected);
     }
+
+    #[test]
+    fn test_decimal_add_scalar_dyn() {
+        let a = Decimal128Array::from(vec![100, 210, 320])
+            .with_precision_and_scale(38, 2)
+            .unwrap();
+
+        let result = add_scalar_dyn::<Decimal128Type>(&a, 1).unwrap();
+        let result = as_primitive_array::<Decimal128Type>(&result)
+            .clone()
+            .with_precision_and_scale(38, 2)
+            .unwrap();
+        let expected = Decimal128Array::from(vec![101, 211, 321])
+            .with_precision_and_scale(38, 2)
+            .unwrap();
+
+        assert_eq!(&expected, &result);
+    }
 }
diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs
index 6207ab639..02659a5a7 100644
--- a/arrow/src/compute/kernels/arity.rs
+++ b/arrow/src/compute/kernels/arity.rs
@@ -114,9 +114,12 @@ where
     T: ArrowPrimitiveType,
     F: Fn(T::Native) -> Result<T::Native>,
 {
-    if array.value_type() != T::DATA_TYPE {
+    if std::mem::discriminant(&array.value_type())
+        != std::mem::discriminant(&T::DATA_TYPE)
+    {
         return Err(ArrowError::CastError(format!(
-            "Cannot perform the unary operation on dictionary array of value type {}",
+            "Cannot perform the unary operation of type {} on dictionary array of value type {}",
+            T::DATA_TYPE,
             array.value_type()
         )));
     }
@@ -135,14 +138,15 @@ where
     downcast_dictionary_array! {
         array => unary_dict::<_, F, T>(array, op),
         t => {
-            if t == &T::DATA_TYPE {
+            if std::mem::discriminant(t) == std::mem::discriminant(&T::DATA_TYPE) {
                 Ok(Arc::new(unary::<T, F, T>(
                     array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
                     op,
                 )))
             } else {
                 Err(ArrowError::NotYetImplemented(format!(
-                    "Cannot perform unary operation on array of type {}",
+                    "Cannot perform unary operation of type {} on array of type {}",
+                    T::DATA_TYPE,
                     t
                 )))
             }
@@ -166,14 +170,15 @@ where
             )))
         },
         t => {
-            if t == &T::DATA_TYPE {
+            if std::mem::discriminant(t) == std::mem::discriminant(&T::DATA_TYPE) {
                 Ok(Arc::new(try_unary::<T, F, T>(
                     array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
                     op,
                 )?))
             } else {
                 Err(ArrowError::NotYetImplemented(format!(
-                    "Cannot perform unary operation on array of type {}",
+                    "Cannot perform unary operation of type {} on array of type {}",
+                    T::DATA_TYPE,
                     t
                 )))
             }

Reply via email to