This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 20af94b0a Add negate kernels (#4488) (#4494)
20af94b0a is described below
commit 20af94b0acf8632e6512fad04b92e0602275d6ee
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Mon Jul 10 09:59:08 2023 -0400
Add negate kernels (#4488) (#4494)
* Add negate kernels (#4488)
* Fix doc
* Add Interval tests
* Review feedback
---
arrow-arith/src/arithmetic.rs | 2 +
arrow-arith/src/numeric.rs | 236 ++++++++++++++++++++++++++++++++++++++++++
arrow-array/src/types.rs | 6 ++
3 files changed, 244 insertions(+)
diff --git a/arrow-arith/src/arithmetic.rs b/arrow-arith/src/arithmetic.rs
index 4f6ecc78d..4566afc2e 100644
--- a/arrow-arith/src/arithmetic.rs
+++ b/arrow-arith/src/arithmetic.rs
@@ -510,6 +510,7 @@ pub fn subtract_scalar_checked_dyn<T: ArrowNumericType>(
///
/// This doesn't detect overflow. Once overflowing, the result will wrap
around.
/// For an overflow-checking variant, use `negate_checked` instead.
+#[deprecated(note = "Use arrow_arith::numeric::neg_wrapping")]
pub fn negate<T: ArrowNumericType>(
array: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>, ArrowError> {
@@ -520,6 +521,7 @@ pub fn negate<T: ArrowNumericType>(
///
/// This detects overflow and returns an `Err` for that. For a
non-overflow-checking variant,
/// use `negate` instead.
+#[deprecated(note = "Use arrow_arith::numeric::neg")]
pub fn negate_checked<T: ArrowNumericType>(
array: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>, ArrowError> {
diff --git a/arrow-arith/src/numeric.rs b/arrow-arith/src/numeric.rs
index 816fcaa94..c2e867dc9 100644
--- a/arrow-arith/src/numeric.rs
+++ b/arrow-arith/src/numeric.rs
@@ -74,6 +74,97 @@ pub fn rem(lhs: &dyn Datum, rhs: &dyn Datum) ->
Result<ArrayRef, ArrowError> {
arithmetic_op(Op::Rem, lhs, rhs)
}
+macro_rules! neg_checked {
+ ($t:ty, $a:ident) => {{
+ let array = $a
+ .as_primitive::<$t>()
+ .try_unary::<_, $t, _>(|x| x.neg_checked())?;
+ Ok(Arc::new(array))
+ }};
+}
+
+macro_rules! neg_wrapping {
+ ($t:ty, $a:ident) => {{
+ let array = $a.as_primitive::<$t>().unary::<_, $t>(|x|
x.neg_wrapping());
+ Ok(Arc::new(array))
+ }};
+}
+
+/// Negates each element of `array`, returning an error on overflow
+///
+/// Note: negation of unsigned arrays is not supported and will return an
error,
+/// for wrapping unsigned negation consider using
[`neg_wrapping`][neg_wrapping()]
+pub fn neg(array: &dyn Array) -> Result<ArrayRef, ArrowError> {
+ use DataType::*;
+ use IntervalUnit::*;
+ use TimeUnit::*;
+
+ match array.data_type() {
+ Int8 => neg_checked!(Int8Type, array),
+ Int16 => neg_checked!(Int16Type, array),
+ Int32 => neg_checked!(Int32Type, array),
+ Int64 => neg_checked!(Int64Type, array),
+ Float16 => neg_wrapping!(Float16Type, array),
+ Float32 => neg_wrapping!(Float32Type, array),
+ Float64 => neg_wrapping!(Float64Type, array),
+ Decimal128(p, s) => {
+ let a = array
+ .as_primitive::<Decimal128Type>()
+ .try_unary::<_, Decimal128Type, _>(|x| x.neg_checked())?;
+
+ Ok(Arc::new(a.with_precision_and_scale(*p, *s)?))
+ }
+ Decimal256(p, s) => {
+ let a = array
+ .as_primitive::<Decimal256Type>()
+ .try_unary::<_, Decimal256Type, _>(|x| x.neg_checked())?;
+
+ Ok(Arc::new(a.with_precision_and_scale(*p, *s)?))
+ }
+ Duration(Second) => neg_checked!(DurationSecondType, array),
+ Duration(Millisecond) => neg_checked!(DurationMillisecondType, array),
+ Duration(Microsecond) => neg_checked!(DurationMicrosecondType, array),
+ Duration(Nanosecond) => neg_checked!(DurationNanosecondType, array),
+ Interval(YearMonth) => neg_checked!(IntervalYearMonthType, array),
+ Interval(DayTime) => {
+ let a = array
+ .as_primitive::<IntervalDayTimeType>()
+ .try_unary::<_, IntervalDayTimeType, ArrowError>(|x| {
+ let (days, ms) = IntervalDayTimeType::to_parts(x);
+ Ok(IntervalDayTimeType::make_value(
+ days.neg_checked()?,
+ ms.neg_checked()?,
+ ))
+ })?;
+ Ok(Arc::new(a))
+ }
+ Interval(MonthDayNano) => {
+ let a = array
+ .as_primitive::<IntervalMonthDayNanoType>()
+ .try_unary::<_, IntervalMonthDayNanoType, ArrowError>(|x| {
+ let (months, days, nanos) =
IntervalMonthDayNanoType::to_parts(x);
+ Ok(IntervalMonthDayNanoType::make_value(
+ months.neg_checked()?,
+ days.neg_checked()?,
+ nanos.neg_checked()?,
+ ))
+ })?;
+ Ok(Arc::new(a))
+ }
+ t => Err(ArrowError::InvalidArgumentError(format!(
+ "Invalid arithmetic operation: !{t}"
+ ))),
+ }
+}
+
+/// Negates each element of `array`, wrapping on overflow for
[`DataType::is_integer`]
+pub fn neg_wrapping(array: &dyn Array) -> Result<ArrayRef, ArrowError> {
+ downcast_integer! {
+ array.data_type() => (neg_wrapping, array),
+ _ => neg(array),
+ }
+}
+
/// An enumeration of arithmetic operations
///
/// This allows sharing the type dispatch logic across the various kernels
@@ -670,3 +761,148 @@ fn decimal_op<T: DecimalType>(
Ok(Arc::new(array))
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use arrow_buffer::{i256, ScalarBuffer};
+
+ fn test_neg_primitive<T: ArrowPrimitiveType>(
+ input: &[T::Native],
+ out: Result<&[T::Native], &str>,
+ ) {
+ let a = PrimitiveArray::<T>::new(ScalarBuffer::from(input.to_vec()),
None);
+ match out {
+ Ok(expected) => {
+ let result = neg(&a).unwrap();
+ assert_eq!(result.as_primitive::<T>().values(), expected);
+ }
+ Err(e) => {
+ let err = neg(&a).unwrap_err().to_string();
+ assert_eq!(e, err);
+ }
+ }
+ }
+
+ #[test]
+ fn test_neg() {
+ let input = &[1, -5, 2, 693, 3929];
+ let output = &[-1, 5, -2, -693, -3929];
+ test_neg_primitive::<Int32Type>(input, Ok(output));
+
+ let input = &[1, -5, 2, 693, 3929];
+ let output = &[-1, 5, -2, -693, -3929];
+ test_neg_primitive::<Int64Type>(input, Ok(output));
+ test_neg_primitive::<DurationSecondType>(input, Ok(output));
+ test_neg_primitive::<DurationMillisecondType>(input, Ok(output));
+ test_neg_primitive::<DurationMicrosecondType>(input, Ok(output));
+ test_neg_primitive::<DurationNanosecondType>(input, Ok(output));
+
+ let input = &[f32::MAX, f32::MIN, f32::INFINITY, 1.3, 0.5];
+ let output = &[f32::MIN, f32::MAX, f32::NEG_INFINITY, -1.3, -0.5];
+ test_neg_primitive::<Float32Type>(input, Ok(output));
+
+ test_neg_primitive::<Int32Type>(
+ &[i32::MIN],
+ Err("Compute error: Overflow happened on: -2147483648"),
+ );
+ test_neg_primitive::<Int64Type>(
+ &[i64::MIN],
+ Err("Compute error: Overflow happened on: -9223372036854775808"),
+ );
+ test_neg_primitive::<DurationSecondType>(
+ &[i64::MIN],
+ Err("Compute error: Overflow happened on: -9223372036854775808"),
+ );
+
+ let r = neg_wrapping(&Int32Array::from(vec![i32::MIN])).unwrap();
+ assert_eq!(r.as_primitive::<Int32Type>().value(0), i32::MIN);
+
+ let r = neg_wrapping(&Int64Array::from(vec![i64::MIN])).unwrap();
+ assert_eq!(r.as_primitive::<Int64Type>().value(0), i64::MIN);
+
+ let err = neg_wrapping(&DurationSecondArray::from(vec![i64::MIN]))
+ .unwrap_err()
+ .to_string();
+
+ assert_eq!(
+ err,
+ "Compute error: Overflow happened on: -9223372036854775808"
+ );
+
+ let a = Decimal128Array::from(vec![1, 3, -44, 2, 4])
+ .with_precision_and_scale(9, 6)
+ .unwrap();
+
+ let r = neg(&a).unwrap();
+ assert_eq!(r.data_type(), a.data_type());
+ assert_eq!(
+ r.as_primitive::<Decimal128Type>().values(),
+ &[-1, -3, 44, -2, -4]
+ );
+
+ let a = Decimal256Array::from(vec![
+ i256::from_i128(342),
+ i256::from_i128(-4949),
+ i256::from_i128(3),
+ ])
+ .with_precision_and_scale(9, 6)
+ .unwrap();
+
+ let r = neg(&a).unwrap();
+ assert_eq!(r.data_type(), a.data_type());
+ assert_eq!(
+ r.as_primitive::<Decimal256Type>().values(),
+ &[
+ i256::from_i128(-342),
+ i256::from_i128(4949),
+ i256::from_i128(-3),
+ ]
+ );
+
+ let a = IntervalYearMonthArray::from(vec![
+ IntervalYearMonthType::make_value(2, 4),
+ IntervalYearMonthType::make_value(2, -4),
+ IntervalYearMonthType::make_value(-3, -5),
+ ]);
+ let r = neg(&a).unwrap();
+ assert_eq!(
+ r.as_primitive::<IntervalYearMonthType>().values(),
+ &[
+ IntervalYearMonthType::make_value(-2, -4),
+ IntervalYearMonthType::make_value(-2, 4),
+ IntervalYearMonthType::make_value(3, 5),
+ ]
+ );
+
+ let a = IntervalDayTimeArray::from(vec![
+ IntervalDayTimeType::make_value(2, 4),
+ IntervalDayTimeType::make_value(2, -4),
+ IntervalDayTimeType::make_value(-3, -5),
+ ]);
+ let r = neg(&a).unwrap();
+ assert_eq!(
+ r.as_primitive::<IntervalDayTimeType>().values(),
+ &[
+ IntervalDayTimeType::make_value(-2, -4),
+ IntervalDayTimeType::make_value(-2, 4),
+ IntervalDayTimeType::make_value(3, 5),
+ ]
+ );
+
+ let a = IntervalMonthDayNanoArray::from(vec![
+ IntervalMonthDayNanoType::make_value(2, 4, 5953394),
+ IntervalMonthDayNanoType::make_value(2, -4, -45839),
+ IntervalMonthDayNanoType::make_value(-3, -5, 6944),
+ ]);
+ let r = neg(&a).unwrap();
+ assert_eq!(
+ r.as_primitive::<IntervalMonthDayNanoType>().values(),
+ &[
+ IntervalMonthDayNanoType::make_value(-2, -4, -5953394),
+ IntervalMonthDayNanoType::make_value(-2, 4, 45839),
+ IntervalMonthDayNanoType::make_value(3, 5, -6944),
+ ]
+ );
+ }
+}
diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs
index f99e6a8f6..0a65c64ad 100644
--- a/arrow-array/src/types.rs
+++ b/arrow-array/src/types.rs
@@ -1001,6 +1001,7 @@ impl IntervalYearMonthType {
///
/// * `years` - The number of years (+/-) represented in this interval
/// * `months` - The number of months (+/-) represented in this interval
+ #[inline]
pub fn make_value(
years: i32,
months: i32,
@@ -1015,6 +1016,7 @@ impl IntervalYearMonthType {
/// # Arguments
///
/// * `i` - The IntervalYearMonthType::Native to convert
+ #[inline]
pub fn to_months(i: <IntervalYearMonthType as ArrowPrimitiveType>::Native)
-> i32 {
i
}
@@ -1027,6 +1029,7 @@ impl IntervalDayTimeType {
///
/// * `days` - The number of days (+/-) represented in this interval
/// * `millis` - The number of milliseconds (+/-) represented in this
interval
+ #[inline]
pub fn make_value(
days: i32,
millis: i32,
@@ -1053,6 +1056,7 @@ impl IntervalDayTimeType {
/// # Arguments
///
/// * `i` - The IntervalDayTimeType to convert
+ #[inline]
pub fn to_parts(
i: <IntervalDayTimeType as ArrowPrimitiveType>::Native,
) -> (i32, i32) {
@@ -1070,6 +1074,7 @@ impl IntervalMonthDayNanoType {
/// * `months` - The number of months (+/-) represented in this interval
/// * `days` - The number of days (+/-) represented in this interval
/// * `nanos` - The number of nanoseconds (+/-) represented in this
interval
+ #[inline]
pub fn make_value(
months: i32,
days: i32,
@@ -1098,6 +1103,7 @@ impl IntervalMonthDayNanoType {
/// # Arguments
///
/// * `i` - The IntervalMonthDayNanoType to convert
+ #[inline]
pub fn to_parts(
i: <IntervalMonthDayNanoType as ArrowPrimitiveType>::Native,
) -> (i32, i32, i64) {