This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new ed16d9f8c Arrow Arithmetic: Subtract timestamps (#4244)
ed16d9f8c is described below

commit ed16d9f8c0dc29b1019d20cfde8b874c22dd838d
Author: Josh Wiley <[email protected]>
AuthorDate: Fri May 19 01:25:39 2023 -0700

    Arrow Arithmetic: Subtract timestamps (#4244)
    
    * feat(arith): subtract timestamps
    
    * feat(arith): checked and unchecked subtraction for timestamps
    
    * feat(arith): use closure for ts sub
---
 arrow-arith/src/arithmetic.rs | 218 +++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 216 insertions(+), 2 deletions(-)

diff --git a/arrow-arith/src/arithmetic.rs b/arrow-arith/src/arithmetic.rs
index 42f6e3974..c3c5cb864 100644
--- a/arrow-arith/src/arithmetic.rs
+++ b/arrow-arith/src/arithmetic.rs
@@ -1096,13 +1096,17 @@ pub fn subtract_dyn(left: &dyn Array, right: &dyn 
Array) -> Result<ArrayRef, Arr
                     let res = math_checked_op(l, r, 
TimestampSecondType::subtract_month_day_nano)?;
                     Ok(Arc::new(res.with_timezone_opt(l.timezone())))
                 }
+                DataType::Timestamp(TimeUnit::Second, _) => {
+                    let r = right.as_primitive::<TimestampSecondType>();
+                    let res: PrimitiveArray<DurationSecondType> = binary(l, r, 
|a, b| a.wrapping_sub(b))?;
+                    Ok(Arc::new(res))
+                }
                 _ => Err(ArrowError::CastError(format!(
                     "Cannot perform arithmetic operation between array of type 
{} and array of type {}",
                     left.data_type(), right.data_type()
                 ))),
             }
         }
-
         DataType::Timestamp(TimeUnit::Microsecond, _) => {
             let l = left.as_primitive::<TimestampMicrosecondType>();
             match right.data_type() {
@@ -1121,6 +1125,11 @@ pub fn subtract_dyn(left: &dyn Array, right: &dyn Array) 
-> Result<ArrayRef, Arr
                     let res = math_checked_op(l, r, 
TimestampMicrosecondType::subtract_month_day_nano)?;
                     Ok(Arc::new(res.with_timezone_opt(l.timezone())))
                 }
+                DataType::Timestamp(TimeUnit::Microsecond, _) => {
+                    let r = right.as_primitive::<TimestampMicrosecondType>();
+                    let res: PrimitiveArray<DurationMicrosecondType> = 
binary(l, r, |a, b| a.wrapping_sub(b))?;
+                    Ok(Arc::new(res))
+                }
                 _ => Err(ArrowError::CastError(format!(
                     "Cannot perform arithmetic operation between array of type 
{} and array of type {}",
                     left.data_type(), right.data_type()
@@ -1145,13 +1154,17 @@ pub fn subtract_dyn(left: &dyn Array, right: &dyn 
Array) -> Result<ArrayRef, Arr
                     let res = math_checked_op(l, r, 
TimestampMillisecondType::subtract_month_day_nano)?;
                     Ok(Arc::new(res.with_timezone_opt(l.timezone())))
                 }
+                DataType::Timestamp(TimeUnit::Millisecond, _) => {
+                    let r = right.as_primitive::<TimestampMillisecondType>();
+                    let res: PrimitiveArray<DurationMillisecondType> = 
binary(l, r, |a, b| a.wrapping_sub(b))?;
+                    Ok(Arc::new(res))
+                }
                 _ => Err(ArrowError::CastError(format!(
                     "Cannot perform arithmetic operation between array of type 
{} and array of type {}",
                     left.data_type(), right.data_type()
                 ))),
             }
         }
-
         DataType::Timestamp(TimeUnit::Nanosecond, _) => {
             let l = left.as_primitive::<TimestampNanosecondType>();
             match right.data_type() {
@@ -1170,6 +1183,11 @@ pub fn subtract_dyn(left: &dyn Array, right: &dyn Array) 
-> Result<ArrayRef, Arr
                     let res = math_checked_op(l, r, 
TimestampNanosecondType::subtract_month_day_nano)?;
                     Ok(Arc::new(res.with_timezone_opt(l.timezone())))
                 }
+                DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+                    let r = right.as_primitive::<TimestampNanosecondType>();
+                    let res: PrimitiveArray<DurationNanosecondType> = 
binary(l, r, |a, b| a.wrapping_sub(b))?;
+                    Ok(Arc::new(res))
+                }
                 _ => Err(ArrowError::CastError(format!(
                     "Cannot perform arithmetic operation between array of type 
{} and array of type {}",
                     left.data_type(), right.data_type()
@@ -1256,6 +1274,62 @@ pub fn subtract_dyn_checked(
                 ))),
             }
         }
+        DataType::Timestamp(TimeUnit::Second, _) => {
+            let l = left.as_primitive::<TimestampSecondType>();
+            match right.data_type() {
+                DataType::Timestamp(TimeUnit::Second, _) => {
+                    let r = right.as_primitive::<TimestampSecondType>();
+                    let res: PrimitiveArray<DurationSecondType> = 
try_binary(l, r, |a, b| a.sub_checked(b))?;
+                    Ok(Arc::new(res))
+                }
+                _ => Err(ArrowError::CastError(format!(
+                    "Cannot perform arithmetic operation between array of type 
{} and array of type {}",
+                    left.data_type(), right.data_type()
+                ))),
+            }
+        }
+        DataType::Timestamp(TimeUnit::Microsecond, _) => {
+            let l = left.as_primitive::<TimestampMicrosecondType>();
+            match right.data_type() {
+                DataType::Timestamp(TimeUnit::Microsecond, _) => {
+                    let r = right.as_primitive::<TimestampMicrosecondType>();
+                    let res: PrimitiveArray<DurationMicrosecondType> = 
try_binary(l, r, |a, b| a.sub_checked(b))?;
+                    Ok(Arc::new(res))
+                }
+                _ => Err(ArrowError::CastError(format!(
+                    "Cannot perform arithmetic operation between array of type 
{} and array of type {}",
+                    left.data_type(), right.data_type()
+                ))),
+            }
+        }
+        DataType::Timestamp(TimeUnit::Millisecond, _) => {
+            let l = left.as_primitive::<TimestampMillisecondType>();
+            match right.data_type() {
+                DataType::Timestamp(TimeUnit::Millisecond, _) => {
+                    let r = right.as_primitive::<TimestampMillisecondType>();
+                    let res: PrimitiveArray<DurationMillisecondType> = 
try_binary(l, r, |a, b| a.sub_checked(b))?;
+                    Ok(Arc::new(res))
+                }
+                _ => Err(ArrowError::CastError(format!(
+                    "Cannot perform arithmetic operation between array of type 
{} and array of type {}",
+                    left.data_type(), right.data_type()
+                ))),
+            }
+        }
+        DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+            let l = left.as_primitive::<TimestampNanosecondType>();
+            match right.data_type() {
+                DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+                    let r = right.as_primitive::<TimestampNanosecondType>();
+                    let res: PrimitiveArray<DurationNanosecondType> = 
try_binary(l, r, |a, b| a.sub_checked(b))?;
+                    Ok(Arc::new(res))
+                }
+                _ => Err(ArrowError::CastError(format!(
+                    "Cannot perform arithmetic operation between array of type 
{} and array of type {}",
+                    left.data_type(), right.data_type()
+                ))),
+            }
+        }
         _ => {
             downcast_primitive_array!(
                 (left, right) => {
@@ -4649,4 +4723,144 @@ mod tests {
         ]);
         assert_eq!(&expected, result);
     }
+
+    #[test]
+    fn test_timestamp_second_subtract_timestamp() {
+        let a = TimestampSecondArray::from(vec![0, 2, 4, 6, 8]);
+        let b = TimestampSecondArray::from(vec![1, 2, 3, 4, 5]);
+        let expected = DurationSecondArray::from(vec![-1, 0, 1, 2, 3]);
+
+        // unchecked
+        let result = subtract_dyn(&a, &b).unwrap();
+        let result = result.as_primitive::<DurationSecondType>();
+        assert_eq!(&expected, result);
+
+        // checked
+        let result = subtract_dyn_checked(&a, &b).unwrap();
+        let result = result.as_primitive::<DurationSecondType>();
+        assert_eq!(&expected, result);
+    }
+
+    #[test]
+    fn test_timestamp_second_subtract_timestamp_overflow() {
+        let a = TimestampSecondArray::from(vec![
+            <TimestampSecondType as ArrowPrimitiveType>::Native::MAX,
+        ]);
+        let b = TimestampSecondArray::from(vec![
+            <TimestampSecondType as ArrowPrimitiveType>::Native::MIN,
+        ]);
+
+        // unchecked
+        let result = subtract_dyn(&a, &b);
+        assert!(!&result.is_err());
+
+        // checked
+        let result = subtract_dyn_checked(&a, &b);
+        assert!(&result.is_err());
+    }
+
+    #[test]
+    fn test_timestamp_microsecond_subtract_timestamp() {
+        let a = TimestampMicrosecondArray::from(vec![0, 2, 4, 6, 8]);
+        let b = TimestampMicrosecondArray::from(vec![1, 2, 3, 4, 5]);
+        let expected = DurationMicrosecondArray::from(vec![-1, 0, 1, 2, 3]);
+
+        // unchecked
+        let result = subtract_dyn(&a, &b).unwrap();
+        let result = result.as_primitive::<DurationMicrosecondType>();
+        assert_eq!(&expected, result);
+
+        // checked
+        let result = subtract_dyn_checked(&a, &b).unwrap();
+        let result = result.as_primitive::<DurationMicrosecondType>();
+        assert_eq!(&expected, result);
+    }
+
+    #[test]
+    fn test_timestamp_microsecond_subtract_timestamp_overflow() {
+        let a = TimestampMicrosecondArray::from(vec![
+            <TimestampMicrosecondType as ArrowPrimitiveType>::Native::MAX,
+        ]);
+        let b = TimestampMicrosecondArray::from(vec![
+            <TimestampMicrosecondType as ArrowPrimitiveType>::Native::MIN,
+        ]);
+
+        // unchecked
+        let result = subtract_dyn(&a, &b);
+        assert!(!&result.is_err());
+
+        // checked
+        let result = subtract_dyn_checked(&a, &b);
+        assert!(&result.is_err());
+    }
+
+    #[test]
+    fn test_timestamp_millisecond_subtract_timestamp() {
+        let a = TimestampMillisecondArray::from(vec![0, 2, 4, 6, 8]);
+        let b = TimestampMillisecondArray::from(vec![1, 2, 3, 4, 5]);
+        let expected = DurationMillisecondArray::from(vec![-1, 0, 1, 2, 3]);
+
+        // unchecked
+        let result = subtract_dyn(&a, &b).unwrap();
+        let result = result.as_primitive::<DurationMillisecondType>();
+        assert_eq!(&expected, result);
+
+        // checked
+        let result = subtract_dyn_checked(&a, &b).unwrap();
+        let result = result.as_primitive::<DurationMillisecondType>();
+        assert_eq!(&expected, result);
+    }
+
+    #[test]
+    fn test_timestamp_millisecond_subtract_timestamp_overflow() {
+        let a = TimestampMillisecondArray::from(vec![
+            <TimestampMillisecondType as ArrowPrimitiveType>::Native::MAX,
+        ]);
+        let b = TimestampMillisecondArray::from(vec![
+            <TimestampMillisecondType as ArrowPrimitiveType>::Native::MIN,
+        ]);
+
+        // unchecked
+        let result = subtract_dyn(&a, &b);
+        assert!(!&result.is_err());
+
+        // checked
+        let result = subtract_dyn_checked(&a, &b);
+        assert!(&result.is_err());
+    }
+
+    #[test]
+    fn test_timestamp_nanosecond_subtract_timestamp() {
+        let a = TimestampNanosecondArray::from(vec![0, 2, 4, 6, 8]);
+        let b = TimestampNanosecondArray::from(vec![1, 2, 3, 4, 5]);
+        let expected = DurationNanosecondArray::from(vec![-1, 0, 1, 2, 3]);
+
+        // unchecked
+        let result = subtract_dyn(&a, &b).unwrap();
+        let result = result.as_primitive::<DurationNanosecondType>();
+        assert_eq!(&expected, result);
+
+        // checked
+        let result = subtract_dyn_checked(&a, &b).unwrap();
+        let result = result.as_primitive::<DurationNanosecondType>();
+        assert_eq!(&expected, result);
+    }
+
+    #[test]
+    fn test_timestamp_nanosecond_subtract_timestamp_overflow() {
+        let a = TimestampNanosecondArray::from(vec![
+            <TimestampNanosecondType as ArrowPrimitiveType>::Native::MAX,
+        ]);
+        let b = TimestampNanosecondArray::from(vec![
+            <TimestampNanosecondType as ArrowPrimitiveType>::Native::MIN,
+        ]);
+
+        // unchecked
+        let result = subtract_dyn(&a, &b);
+        assert!(!&result.is_err());
+
+        // checked
+        let result = subtract_dyn_checked(&a, &b);
+        assert!(&result.is_err());
+    }
 }

Reply via email to