This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 20af94b0a Add negate kernels (#4488) (#4494)
20af94b0a is described below

commit 20af94b0acf8632e6512fad04b92e0602275d6ee
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Mon Jul 10 09:59:08 2023 -0400

    Add negate kernels (#4488) (#4494)
    
    * Add negate kernels (#4488)
    
    * Fix doc
    
    * Add Inteval tests
    
    * Review feedback
---
 arrow-arith/src/arithmetic.rs |   2 +
 arrow-arith/src/numeric.rs    | 236 ++++++++++++++++++++++++++++++++++++++++++
 arrow-array/src/types.rs      |   6 ++
 3 files changed, 244 insertions(+)

diff --git a/arrow-arith/src/arithmetic.rs b/arrow-arith/src/arithmetic.rs
index 4f6ecc78d..4566afc2e 100644
--- a/arrow-arith/src/arithmetic.rs
+++ b/arrow-arith/src/arithmetic.rs
@@ -510,6 +510,7 @@ pub fn subtract_scalar_checked_dyn<T: ArrowNumericType>(
 ///
 /// This doesn't detect overflow. Once overflowing, the result will wrap 
around.
 /// For an overflow-checking variant, use `negate_checked` instead.
+#[deprecated(note = "Use arrow_arith::numeric::neg_wrapping")]
 pub fn negate<T: ArrowNumericType>(
     array: &PrimitiveArray<T>,
 ) -> Result<PrimitiveArray<T>, ArrowError> {
@@ -520,6 +521,7 @@ pub fn negate<T: ArrowNumericType>(
 ///
 /// This detects overflow and returns an `Err` for that. For an 
non-overflow-checking variant,
 /// use `negate` instead.
+#[deprecated(note = "Use arrow_arith::numeric::neg")]
 pub fn negate_checked<T: ArrowNumericType>(
     array: &PrimitiveArray<T>,
 ) -> Result<PrimitiveArray<T>, ArrowError> {
diff --git a/arrow-arith/src/numeric.rs b/arrow-arith/src/numeric.rs
index 816fcaa94..c2e867dc9 100644
--- a/arrow-arith/src/numeric.rs
+++ b/arrow-arith/src/numeric.rs
@@ -74,6 +74,97 @@ pub fn rem(lhs: &dyn Datum, rhs: &dyn Datum) -> 
Result<ArrayRef, ArrowError> {
     arithmetic_op(Op::Rem, lhs, rhs)
 }
 
+macro_rules! neg_checked {
+    ($t:ty, $a:ident) => {{
+        let array = $a
+            .as_primitive::<$t>()
+            .try_unary::<_, $t, _>(|x| x.neg_checked())?;
+        Ok(Arc::new(array))
+    }};
+}
+
+macro_rules! neg_wrapping {
+    ($t:ty, $a:ident) => {{
+        let array = $a.as_primitive::<$t>().unary::<_, $t>(|x| 
x.neg_wrapping());
+        Ok(Arc::new(array))
+    }};
+}
+
+/// Negates each element of `array`, returning an error on overflow
+///
+/// Note: negation of unsigned arrays is not supported and will result in an 
error,
+/// for wrapping unsigned negation consider using 
[`neg_wrapping`][neg_wrapping()]
+pub fn neg(array: &dyn Array) -> Result<ArrayRef, ArrowError> {
+    use DataType::*;
+    use IntervalUnit::*;
+    use TimeUnit::*;
+
+    match array.data_type() {
+        Int8 => neg_checked!(Int8Type, array),
+        Int16 => neg_checked!(Int16Type, array),
+        Int32 => neg_checked!(Int32Type, array),
+        Int64 => neg_checked!(Int64Type, array),
+        Float16 => neg_wrapping!(Float16Type, array),
+        Float32 => neg_wrapping!(Float32Type, array),
+        Float64 => neg_wrapping!(Float64Type, array),
+        Decimal128(p, s) => {
+            let a = array
+                .as_primitive::<Decimal128Type>()
+                .try_unary::<_, Decimal128Type, _>(|x| x.neg_checked())?;
+
+            Ok(Arc::new(a.with_precision_and_scale(*p, *s)?))
+        }
+        Decimal256(p, s) => {
+            let a = array
+                .as_primitive::<Decimal256Type>()
+                .try_unary::<_, Decimal256Type, _>(|x| x.neg_checked())?;
+
+            Ok(Arc::new(a.with_precision_and_scale(*p, *s)?))
+        }
+        Duration(Second) => neg_checked!(DurationSecondType, array),
+        Duration(Millisecond) => neg_checked!(DurationMillisecondType, array),
+        Duration(Microsecond) => neg_checked!(DurationMicrosecondType, array),
+        Duration(Nanosecond) => neg_checked!(DurationNanosecondType, array),
+        Interval(YearMonth) => neg_checked!(IntervalYearMonthType, array),
+        Interval(DayTime) => {
+            let a = array
+                .as_primitive::<IntervalDayTimeType>()
+                .try_unary::<_, IntervalDayTimeType, ArrowError>(|x| {
+                    let (days, ms) = IntervalDayTimeType::to_parts(x);
+                    Ok(IntervalDayTimeType::make_value(
+                        days.neg_checked()?,
+                        ms.neg_checked()?,
+                    ))
+                })?;
+            Ok(Arc::new(a))
+        }
+        Interval(MonthDayNano) => {
+            let a = array
+                .as_primitive::<IntervalMonthDayNanoType>()
+                .try_unary::<_, IntervalMonthDayNanoType, ArrowError>(|x| {
+                let (months, days, nanos) = 
IntervalMonthDayNanoType::to_parts(x);
+                Ok(IntervalMonthDayNanoType::make_value(
+                    months.neg_checked()?,
+                    days.neg_checked()?,
+                    nanos.neg_checked()?,
+                ))
+            })?;
+            Ok(Arc::new(a))
+        }
+        t => Err(ArrowError::InvalidArgumentError(format!(
+            "Invalid arithmetic operation: !{t}"
+        ))),
+    }
+}
+
+/// Negates each element of `array`, wrapping on overflow for 
[`DataType::is_integer`]
+pub fn neg_wrapping(array: &dyn Array) -> Result<ArrayRef, ArrowError> {
+    downcast_integer! {
+        array.data_type() => (neg_wrapping, array),
+        _ => neg(array),
+    }
+}
+
 /// An enumeration of arithmetic operations
 ///
 /// This allows sharing the type dispatch logic across the various kernels
@@ -670,3 +761,148 @@ fn decimal_op<T: DecimalType>(
 
     Ok(Arc::new(array))
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow_buffer::{i256, ScalarBuffer};
+
+    fn test_neg_primitive<T: ArrowPrimitiveType>(
+        input: &[T::Native],
+        out: Result<&[T::Native], &str>,
+    ) {
+        let a = PrimitiveArray::<T>::new(ScalarBuffer::from(input.to_vec()), 
None);
+        match out {
+            Ok(expected) => {
+                let result = neg(&a).unwrap();
+                assert_eq!(result.as_primitive::<T>().values(), expected);
+            }
+            Err(e) => {
+                let err = neg(&a).unwrap_err().to_string();
+                assert_eq!(e, err);
+            }
+        }
+    }
+
+    #[test]
+    fn test_neg() {
+        let input = &[1, -5, 2, 693, 3929];
+        let output = &[-1, 5, -2, -693, -3929];
+        test_neg_primitive::<Int32Type>(input, Ok(output));
+
+        let input = &[1, -5, 2, 693, 3929];
+        let output = &[-1, 5, -2, -693, -3929];
+        test_neg_primitive::<Int64Type>(input, Ok(output));
+        test_neg_primitive::<DurationSecondType>(input, Ok(output));
+        test_neg_primitive::<DurationMillisecondType>(input, Ok(output));
+        test_neg_primitive::<DurationMicrosecondType>(input, Ok(output));
+        test_neg_primitive::<DurationNanosecondType>(input, Ok(output));
+
+        let input = &[f32::MAX, f32::MIN, f32::INFINITY, 1.3, 0.5];
+        let output = &[f32::MIN, f32::MAX, f32::NEG_INFINITY, -1.3, -0.5];
+        test_neg_primitive::<Float32Type>(input, Ok(output));
+
+        test_neg_primitive::<Int32Type>(
+            &[i32::MIN],
+            Err("Compute error: Overflow happened on: -2147483648"),
+        );
+        test_neg_primitive::<Int64Type>(
+            &[i64::MIN],
+            Err("Compute error: Overflow happened on: -9223372036854775808"),
+        );
+        test_neg_primitive::<DurationSecondType>(
+            &[i64::MIN],
+            Err("Compute error: Overflow happened on: -9223372036854775808"),
+        );
+
+        let r = neg_wrapping(&Int32Array::from(vec![i32::MIN])).unwrap();
+        assert_eq!(r.as_primitive::<Int32Type>().value(0), i32::MIN);
+
+        let r = neg_wrapping(&Int64Array::from(vec![i64::MIN])).unwrap();
+        assert_eq!(r.as_primitive::<Int64Type>().value(0), i64::MIN);
+
+        let err = neg_wrapping(&DurationSecondArray::from(vec![i64::MIN]))
+            .unwrap_err()
+            .to_string();
+
+        assert_eq!(
+            err,
+            "Compute error: Overflow happened on: -9223372036854775808"
+        );
+
+        let a = Decimal128Array::from(vec![1, 3, -44, 2, 4])
+            .with_precision_and_scale(9, 6)
+            .unwrap();
+
+        let r = neg(&a).unwrap();
+        assert_eq!(r.data_type(), a.data_type());
+        assert_eq!(
+            r.as_primitive::<Decimal128Type>().values(),
+            &[-1, -3, 44, -2, -4]
+        );
+
+        let a = Decimal256Array::from(vec![
+            i256::from_i128(342),
+            i256::from_i128(-4949),
+            i256::from_i128(3),
+        ])
+        .with_precision_and_scale(9, 6)
+        .unwrap();
+
+        let r = neg(&a).unwrap();
+        assert_eq!(r.data_type(), a.data_type());
+        assert_eq!(
+            r.as_primitive::<Decimal256Type>().values(),
+            &[
+                i256::from_i128(-342),
+                i256::from_i128(4949),
+                i256::from_i128(-3),
+            ]
+        );
+
+        let a = IntervalYearMonthArray::from(vec![
+            IntervalYearMonthType::make_value(2, 4),
+            IntervalYearMonthType::make_value(2, -4),
+            IntervalYearMonthType::make_value(-3, -5),
+        ]);
+        let r = neg(&a).unwrap();
+        assert_eq!(
+            r.as_primitive::<IntervalYearMonthType>().values(),
+            &[
+                IntervalYearMonthType::make_value(-2, -4),
+                IntervalYearMonthType::make_value(-2, 4),
+                IntervalYearMonthType::make_value(3, 5),
+            ]
+        );
+
+        let a = IntervalDayTimeArray::from(vec![
+            IntervalDayTimeType::make_value(2, 4),
+            IntervalDayTimeType::make_value(2, -4),
+            IntervalDayTimeType::make_value(-3, -5),
+        ]);
+        let r = neg(&a).unwrap();
+        assert_eq!(
+            r.as_primitive::<IntervalDayTimeType>().values(),
+            &[
+                IntervalDayTimeType::make_value(-2, -4),
+                IntervalDayTimeType::make_value(-2, 4),
+                IntervalDayTimeType::make_value(3, 5),
+            ]
+        );
+
+        let a = IntervalMonthDayNanoArray::from(vec![
+            IntervalMonthDayNanoType::make_value(2, 4, 5953394),
+            IntervalMonthDayNanoType::make_value(2, -4, -45839),
+            IntervalMonthDayNanoType::make_value(-3, -5, 6944),
+        ]);
+        let r = neg(&a).unwrap();
+        assert_eq!(
+            r.as_primitive::<IntervalMonthDayNanoType>().values(),
+            &[
+                IntervalMonthDayNanoType::make_value(-2, -4, -5953394),
+                IntervalMonthDayNanoType::make_value(-2, 4, 45839),
+                IntervalMonthDayNanoType::make_value(3, 5, -6944),
+            ]
+        );
+    }
+}
diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs
index f99e6a8f6..0a65c64ad 100644
--- a/arrow-array/src/types.rs
+++ b/arrow-array/src/types.rs
@@ -1001,6 +1001,7 @@ impl IntervalYearMonthType {
     ///
     /// * `years` - The number of years (+/-) represented in this interval
     /// * `months` - The number of months (+/-) represented in this interval
+    #[inline]
     pub fn make_value(
         years: i32,
         months: i32,
@@ -1015,6 +1016,7 @@ impl IntervalYearMonthType {
     /// # Arguments
     ///
     /// * `i` - The IntervalYearMonthType::Native to convert
+    #[inline]
     pub fn to_months(i: <IntervalYearMonthType as ArrowPrimitiveType>::Native) 
-> i32 {
         i
     }
@@ -1027,6 +1029,7 @@ impl IntervalDayTimeType {
     ///
     /// * `days` - The number of days (+/-) represented in this interval
     /// * `millis` - The number of milliseconds (+/-) represented in this 
interval
+    #[inline]
     pub fn make_value(
         days: i32,
         millis: i32,
@@ -1053,6 +1056,7 @@ impl IntervalDayTimeType {
     /// # Arguments
     ///
     /// * `i` - The IntervalDayTimeType to convert
+    #[inline]
     pub fn to_parts(
         i: <IntervalDayTimeType as ArrowPrimitiveType>::Native,
     ) -> (i32, i32) {
@@ -1070,6 +1074,7 @@ impl IntervalMonthDayNanoType {
     /// * `months` - The number of months (+/-) represented in this interval
     /// * `days` - The number of days (+/-) represented in this interval
     /// * `nanos` - The number of nanoseconds (+/-) represented in this 
interval
+    #[inline]
     pub fn make_value(
         months: i32,
         days: i32,
@@ -1098,6 +1103,7 @@ impl IntervalMonthDayNanoType {
     /// # Arguments
     ///
     /// * `i` - The IntervalMonthDayNanoType to convert
+    #[inline]
     pub fn to_parts(
         i: <IntervalMonthDayNanoType as ArrowPrimitiveType>::Native,
     ) -> (i32, i32, i64) {

Reply via email to