This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 491b0239a Fix unary_dyn for decimal scalar arithmetic computation (#3345)
491b0239a is described below
commit 491b0239a81bb3e7e2829d69c5a59799a0d4f6e6
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sun Dec 18 14:15:31 2022 -0800
Fix unary_dyn for decimal scalar arithmetic computation (#3345)
* Fix unary for decimal arithmetic computation
* Use discriminant
---
arrow/src/compute/kernels/arithmetic.rs | 20 +++++++++++++++++++-
arrow/src/compute/kernels/arity.rs | 17 +++++++++++------
2 files changed, 30 insertions(+), 7 deletions(-)
diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs
index 23cefe48e..913a2cad6 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -1633,7 +1633,7 @@ mod tests {
use super::*;
use crate::array::Int32Array;
use crate::compute::{binary_mut, try_binary_mut, try_unary_mut, unary_mut};
- use crate::datatypes::{Date64Type, Int32Type, Int8Type};
+ use crate::datatypes::{Date64Type, Decimal128Type, Int32Type, Int8Type};
use arrow_buffer::i256;
use chrono::NaiveDate;
use half::f16;
@@ -3226,4 +3226,22 @@ mod tests {
])) as ArrayRef;
assert_eq!(&result, &expected);
}
+
+ #[test]
+ fn test_decimal_add_scalar_dyn() {
+ let a = Decimal128Array::from(vec![100, 210, 320])
+ .with_precision_and_scale(38, 2)
+ .unwrap();
+
+ let result = add_scalar_dyn::<Decimal128Type>(&a, 1).unwrap();
+ let result = as_primitive_array::<Decimal128Type>(&result)
+ .clone()
+ .with_precision_and_scale(38, 2)
+ .unwrap();
+ let expected = Decimal128Array::from(vec![101, 211, 321])
+ .with_precision_and_scale(38, 2)
+ .unwrap();
+
+ assert_eq!(&expected, &result);
+ }
}
diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs
index 6207ab639..02659a5a7 100644
--- a/arrow/src/compute/kernels/arity.rs
+++ b/arrow/src/compute/kernels/arity.rs
@@ -114,9 +114,12 @@ where
T: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<T::Native>,
{
- if array.value_type() != T::DATA_TYPE {
+ if std::mem::discriminant(&array.value_type())
+ != std::mem::discriminant(&T::DATA_TYPE)
+ {
return Err(ArrowError::CastError(format!(
- "Cannot perform the unary operation on dictionary array of value type {}",
+ "Cannot perform the unary operation of type {} on dictionary array of value type {}",
+ T::DATA_TYPE,
array.value_type()
)));
}
@@ -135,14 +138,15 @@ where
downcast_dictionary_array! {
array => unary_dict::<_, F, T>(array, op),
t => {
- if t == &T::DATA_TYPE {
+ if std::mem::discriminant(t) == std::mem::discriminant(&T::DATA_TYPE) {
Ok(Arc::new(unary::<T, F, T>(
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
op,
)))
} else {
Err(ArrowError::NotYetImplemented(format!(
- "Cannot perform unary operation on array of type {}",
+ "Cannot perform unary operation of type {} on array of type {}",
+ T::DATA_TYPE,
t
)))
}
@@ -166,14 +170,15 @@ where
)))
},
t => {
- if t == &T::DATA_TYPE {
+ if std::mem::discriminant(t) == std::mem::discriminant(&T::DATA_TYPE) {
Ok(Arc::new(try_unary::<T, F, T>(
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
op,
)?))
} else {
Err(ArrowError::NotYetImplemented(format!(
- "Cannot perform unary operation on array of type {}",
+ "Cannot perform unary operation of type {} on array of type {}",
+ T::DATA_TYPE,
t
)))
}