This is an automated email from the ASF dual-hosted git repository.

wayne pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 2881dbe42a Support cast between Durations + between Durations all 
numeric types (#6452)
2881dbe42a is described below

commit 2881dbe42ad33c5d0bc7f9c6661ba55127511ea1
Author: tison <[email protected]>
AuthorDate: Thu Sep 26 11:03:12 2024 +0800

    Support cast between Durations + between Durations all numeric types (#6452)
    
    * Support cast between Durations
    
    Signed-off-by: tison <[email protected]>
    
    * Support cast between Durations and all numeric types
    
    Signed-off-by: tison <[email protected]>
    
    * Impl cast between Durations
    
    Signed-off-by: tison <[email protected]>
    
    * Add test_cast_between_durations
    
    Signed-off-by: tison <[email protected]>
    
    * add test cases
    
    Signed-off-by: tison <[email protected]>
    
    * cargo clippy
    
    Signed-off-by: tison <[email protected]>
    
    ---------
    
    Signed-off-by: tison <[email protected]>
---
 arrow-cast/src/cast/mod.rs | 174 +++++++++++++++++++++++++++++++++++++++------
 1 file changed, 153 insertions(+), 21 deletions(-)

diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs
index b751c81ee4..e3fad3da19 100644
--- a/arrow-cast/src/cast/mod.rs
+++ b/arrow-cast/src/cast/mod.rs
@@ -271,8 +271,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: 
&DataType) -> bool {
             | Time64(Microsecond)
             | Time64(Nanosecond),
         ) => true,
-        (Int64, Duration(_)) => true,
-        (Duration(_), Int64) => true,
+        (_, Duration(_)) if from_type.is_numeric() => true,
+        (Duration(_), _) if to_type.is_numeric() => true,
+        (Duration(_), Duration(_)) => true,
         (Interval(from_type), Int64) => {
             match from_type {
                 YearMonth => true,
@@ -518,6 +519,15 @@ fn make_timestamp_array(
     }
 }
 
+fn make_duration_array(array: &PrimitiveArray<Int64Type>, unit: TimeUnit) -> 
ArrayRef {
+    match unit {
+        TimeUnit::Second => 
Arc::new(array.reinterpret_cast::<DurationSecondType>()),
+        TimeUnit::Millisecond => 
Arc::new(array.reinterpret_cast::<DurationMillisecondType>()),
+        TimeUnit::Microsecond => 
Arc::new(array.reinterpret_cast::<DurationMicrosecondType>()),
+        TimeUnit::Nanosecond => 
Arc::new(array.reinterpret_cast::<DurationNanosecondType>()),
+    }
+}
+
 fn as_time_res_with_timezone<T: ArrowPrimitiveType>(
     v: i64,
     tz: Option<Tz>,
@@ -2074,31 +2084,53 @@ pub fn cast_with_options(
                 .as_primitive::<Date32Type>()
                 .unary::<_, TimestampNanosecondType>(|x| (x as i64) * 
NANOSECONDS_IN_DAY),
         )),
-        (Int64, Duration(TimeUnit::Second)) => {
-            cast_reinterpret_arrays::<Int64Type, DurationSecondType>(array)
-        }
-        (Int64, Duration(TimeUnit::Millisecond)) => {
-            cast_reinterpret_arrays::<Int64Type, 
DurationMillisecondType>(array)
-        }
-        (Int64, Duration(TimeUnit::Microsecond)) => {
-            cast_reinterpret_arrays::<Int64Type, 
DurationMicrosecondType>(array)
+
+        (_, Duration(unit)) if from_type.is_numeric() => {
+            let array = cast_with_options(array, &Int64, cast_options)?;
+            Ok(make_duration_array(array.as_primitive(), *unit))
         }
-        (Int64, Duration(TimeUnit::Nanosecond)) => {
-            cast_reinterpret_arrays::<Int64Type, DurationNanosecondType>(array)
+        (Duration(TimeUnit::Second), _) if to_type.is_numeric() => {
+            let array = cast_reinterpret_arrays::<DurationSecondType, 
Int64Type>(array)?;
+            cast_with_options(&array, to_type, cast_options)
         }
-
-        (Duration(TimeUnit::Second), Int64) => {
-            cast_reinterpret_arrays::<DurationSecondType, Int64Type>(array)
+        (Duration(TimeUnit::Millisecond), _) if to_type.is_numeric() => {
+            let array = cast_reinterpret_arrays::<DurationMillisecondType, 
Int64Type>(array)?;
+            cast_with_options(&array, to_type, cast_options)
         }
-        (Duration(TimeUnit::Millisecond), Int64) => {
-            cast_reinterpret_arrays::<DurationMillisecondType, 
Int64Type>(array)
+        (Duration(TimeUnit::Microsecond), _) if to_type.is_numeric() => {
+            let array = cast_reinterpret_arrays::<DurationMicrosecondType, 
Int64Type>(array)?;
+            cast_with_options(&array, to_type, cast_options)
         }
-        (Duration(TimeUnit::Microsecond), Int64) => {
-            cast_reinterpret_arrays::<DurationMicrosecondType, 
Int64Type>(array)
+        (Duration(TimeUnit::Nanosecond), _) if to_type.is_numeric() => {
+            let array = cast_reinterpret_arrays::<DurationNanosecondType, 
Int64Type>(array)?;
+            cast_with_options(&array, to_type, cast_options)
         }
-        (Duration(TimeUnit::Nanosecond), Int64) => {
-            cast_reinterpret_arrays::<DurationNanosecondType, Int64Type>(array)
+
+        (Duration(from_unit), Duration(to_unit)) => {
+            let array = cast_with_options(array, &Int64, cast_options)?;
+            let time_array = array.as_primitive::<Int64Type>();
+            let from_size = time_unit_multiple(from_unit);
+            let to_size = time_unit_multiple(to_unit);
+            // Scale by the ratio between the units: divide when the target unit is coarser,
+            // multiply (overflow-checked) when it is finer; equal units pass through unchanged.
+            let converted = match from_size.cmp(&to_size) {
+                Ordering::Greater => {
+                    let divisor = from_size / to_size;
+                    time_array.unary::<_, Int64Type>(|o| o / divisor)
+                }
+                Ordering::Equal => time_array.clone(),
+                Ordering::Less => {
+                    let mul = to_size / from_size;
+                    if cast_options.safe {
+                        time_array.unary_opt::<_, Int64Type>(|o| 
o.checked_mul(mul))
+                    } else {
+                        time_array.try_unary::<_, Int64Type, _>(|o| 
o.mul_checked(mul))?
+                    }
+                }
+            };
+            Ok(make_duration_array(&converted, *to_unit))
         }
+
         (Duration(TimeUnit::Second), Interval(IntervalUnit::MonthDayNano)) => {
             cast_duration_to_interval::<DurationSecondType>(array, 
cast_options)
         }
@@ -5254,6 +5286,106 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_cast_between_durations_and_numerics() {
+        fn test_cast_between_durations<FromType, ToType>()
+        where
+            FromType: ArrowPrimitiveType<Native = i64>,
+            ToType: ArrowPrimitiveType<Native = i64>,
+            PrimitiveArray<FromType>: From<Vec<Option<i64>>>,
+        {
+            let from_unit = match FromType::DATA_TYPE {
+                DataType::Duration(unit) => unit,
+                _ => panic!("Expected a duration type"),
+            };
+            let to_unit = match ToType::DATA_TYPE {
+                DataType::Duration(unit) => unit,
+                _ => panic!("Expected a duration type"),
+            };
+            let from_size = time_unit_multiple(&from_unit);
+            let to_size = time_unit_multiple(&to_unit);
+
+            let (v1_before, v2_before) = (8640003005, 1696002001);
+            let (v1_after, v2_after) = if from_size >= to_size {
+                (
+                    v1_before / (from_size / to_size),
+                    v2_before / (from_size / to_size),
+                )
+            } else {
+                (
+                    v1_before * (to_size / from_size),
+                    v2_before * (to_size / from_size),
+                )
+            };
+
+            let array =
+                PrimitiveArray::<FromType>::from(vec![Some(v1_before), 
Some(v2_before), None]);
+            let b = cast(&array, &ToType::DATA_TYPE).unwrap();
+            let c = b.as_primitive::<ToType>();
+            assert_eq!(v1_after, c.value(0));
+            assert_eq!(v2_after, c.value(1));
+            assert!(c.is_null(2));
+        }
+
+        // between each individual duration type
+        test_cast_between_durations::<DurationSecondType, 
DurationMillisecondType>();
+        test_cast_between_durations::<DurationSecondType, 
DurationMicrosecondType>();
+        test_cast_between_durations::<DurationSecondType, 
DurationNanosecondType>();
+        test_cast_between_durations::<DurationMillisecondType, 
DurationSecondType>();
+        test_cast_between_durations::<DurationMillisecondType, 
DurationMicrosecondType>();
+        test_cast_between_durations::<DurationMillisecondType, 
DurationNanosecondType>();
+        test_cast_between_durations::<DurationMicrosecondType, 
DurationSecondType>();
+        test_cast_between_durations::<DurationMicrosecondType, 
DurationMillisecondType>();
+        test_cast_between_durations::<DurationMicrosecondType, 
DurationNanosecondType>();
+        test_cast_between_durations::<DurationNanosecondType, 
DurationSecondType>();
+        test_cast_between_durations::<DurationNanosecondType, 
DurationMillisecondType>();
+        test_cast_between_durations::<DurationNanosecondType, 
DurationMicrosecondType>();
+
+        // overflow when scaling seconds to nanoseconds yields null, not an error
+        let array = DurationSecondArray::from(vec![
+            Some(i64::MAX),
+            Some(8640203410378005),
+            Some(10241096),
+            None,
+        ]);
+        let b = cast(&array, 
&DataType::Duration(TimeUnit::Nanosecond)).unwrap();
+        let c = b.as_primitive::<DurationNanosecondType>();
+        assert!(c.is_null(0));
+        assert!(c.is_null(1));
+        assert_eq!(10241096000000000, c.value(2));
+        assert!(c.is_null(3));
+
+        // durations to numerics
+        let array = DurationSecondArray::from(vec![
+            Some(i64::MAX),
+            Some(8640203410378005),
+            Some(10241096),
+            None,
+        ]);
+        let b = cast(&array, &DataType::Int64).unwrap();
+        let c = b.as_primitive::<Int64Type>();
+        assert_eq!(i64::MAX, c.value(0));
+        assert_eq!(8640203410378005, c.value(1));
+        assert_eq!(10241096, c.value(2));
+        assert!(c.is_null(3));
+
+        let b = cast(&array, &DataType::Int32).unwrap();
+        let c = b.as_primitive::<Int32Type>();
+        assert_eq!(0, c.value(0));
+        assert_eq!(0, c.value(1));
+        assert_eq!(10241096, c.value(2));
+        assert!(c.is_null(3));
+
+        // numerics to durations
+        let array = Int32Array::from(vec![Some(i32::MAX), Some(802034103), 
Some(10241096), None]);
+        let b = cast(&array, &DataType::Duration(TimeUnit::Second)).unwrap();
+        let c = b.as_any().downcast_ref::<DurationSecondArray>().unwrap();
+        assert_eq!(i32::MAX as i64, c.value(0));
+        assert_eq!(802034103, c.value(1));
+        assert_eq!(10241096, c.value(2));
+        assert!(c.is_null(3));
+    }
+
     #[test]
     fn test_cast_to_strings() {
         let a = Int32Array::from(vec![1, 2, 3]);

Reply via email to