This is an automated email from the ASF dual-hosted git repository.
wayne pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 2881dbe42a Support cast between Durations + between Durations and all
numeric types (#6452)
2881dbe42a is described below
commit 2881dbe42ad33c5d0bc7f9c6661ba55127511ea1
Author: tison <[email protected]>
AuthorDate: Thu Sep 26 11:03:12 2024 +0800
Support cast between Durations + between Durations and all numeric types (#6452)
* Support cast between Durations
Signed-off-by: tison <[email protected]>
* Support cast between Durations and all numeric types
Signed-off-by: tison <[email protected]>
* Impl cast between Durations
Signed-off-by: tison <[email protected]>
* Add test_cast_between_durations
Signed-off-by: tison <[email protected]>
* add test cases
Signed-off-by: tison <[email protected]>
* cargo clippy
Signed-off-by: tison <[email protected]>
---------
Signed-off-by: tison <[email protected]>
---
arrow-cast/src/cast/mod.rs | 174 +++++++++++++++++++++++++++++++++++++++------
1 file changed, 153 insertions(+), 21 deletions(-)
diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs
index b751c81ee4..e3fad3da19 100644
--- a/arrow-cast/src/cast/mod.rs
+++ b/arrow-cast/src/cast/mod.rs
@@ -271,8 +271,9 @@ pub fn can_cast_types(from_type: &DataType, to_type:
&DataType) -> bool {
| Time64(Microsecond)
| Time64(Nanosecond),
) => true,
- (Int64, Duration(_)) => true,
- (Duration(_), Int64) => true,
+ (_, Duration(_)) if from_type.is_numeric() => true,
+ (Duration(_), _) if to_type.is_numeric() => true,
+ (Duration(_), Duration(_)) => true,
(Interval(from_type), Int64) => {
match from_type {
YearMonth => true,
@@ -518,6 +519,15 @@ fn make_timestamp_array(
}
}
+fn make_duration_array(array: &PrimitiveArray<Int64Type>, unit: TimeUnit) ->
ArrayRef {
+ match unit { // reinterpret the i64 values as `unit` durations; no rescaling here (callers convert units first)
+ TimeUnit::Second =>
Arc::new(array.reinterpret_cast::<DurationSecondType>()),
+ TimeUnit::Millisecond =>
Arc::new(array.reinterpret_cast::<DurationMillisecondType>()),
+ TimeUnit::Microsecond =>
Arc::new(array.reinterpret_cast::<DurationMicrosecondType>()),
+ TimeUnit::Nanosecond =>
Arc::new(array.reinterpret_cast::<DurationNanosecondType>()),
+ }
+}
+
fn as_time_res_with_timezone<T: ArrowPrimitiveType>(
v: i64,
tz: Option<Tz>,
@@ -2074,31 +2084,53 @@ pub fn cast_with_options(
.as_primitive::<Date32Type>()
.unary::<_, TimestampNanosecondType>(|x| (x as i64) *
NANOSECONDS_IN_DAY),
)),
- (Int64, Duration(TimeUnit::Second)) => {
- cast_reinterpret_arrays::<Int64Type, DurationSecondType>(array)
- }
- (Int64, Duration(TimeUnit::Millisecond)) => {
- cast_reinterpret_arrays::<Int64Type,
DurationMillisecondType>(array)
- }
- (Int64, Duration(TimeUnit::Microsecond)) => {
- cast_reinterpret_arrays::<Int64Type,
DurationMicrosecondType>(array)
+
+ (_, Duration(unit)) if from_type.is_numeric() => {
+ let array = cast_with_options(array, &Int64, cast_options)?;
+ Ok(make_duration_array(array.as_primitive(), *unit))
}
- (Int64, Duration(TimeUnit::Nanosecond)) => {
- cast_reinterpret_arrays::<Int64Type, DurationNanosecondType>(array)
+ (Duration(TimeUnit::Second), _) if to_type.is_numeric() => {
+ let array = cast_reinterpret_arrays::<DurationSecondType,
Int64Type>(array)?;
+ cast_with_options(&array, to_type, cast_options)
}
-
- (Duration(TimeUnit::Second), Int64) => {
- cast_reinterpret_arrays::<DurationSecondType, Int64Type>(array)
+ (Duration(TimeUnit::Millisecond), _) if to_type.is_numeric() => {
+ let array = cast_reinterpret_arrays::<DurationMillisecondType,
Int64Type>(array)?;
+ cast_with_options(&array, to_type, cast_options)
}
- (Duration(TimeUnit::Millisecond), Int64) => {
- cast_reinterpret_arrays::<DurationMillisecondType,
Int64Type>(array)
+ (Duration(TimeUnit::Microsecond), _) if to_type.is_numeric() => {
+ let array = cast_reinterpret_arrays::<DurationMicrosecondType,
Int64Type>(array)?;
+ cast_with_options(&array, to_type, cast_options)
}
- (Duration(TimeUnit::Microsecond), Int64) => {
- cast_reinterpret_arrays::<DurationMicrosecondType,
Int64Type>(array)
+ (Duration(TimeUnit::Nanosecond), _) if to_type.is_numeric() => {
+ let array = cast_reinterpret_arrays::<DurationNanosecondType,
Int64Type>(array)?;
+ cast_with_options(&array, to_type, cast_options)
}
- (Duration(TimeUnit::Nanosecond), Int64) => {
- cast_reinterpret_arrays::<DurationNanosecondType, Int64Type>(array)
+
+ (Duration(from_unit), Duration(to_unit)) => {
+ let array = cast_with_options(array, &Int64, cast_options)?;
+ let time_array = array.as_primitive::<Int64Type>();
+ let from_size = time_unit_multiple(from_unit);
+ let to_size = time_unit_multiple(to_unit);
+ // we either divide or multiply, depending on size of each unit
+ // units are never the same when the types are the same
+ let converted = match from_size.cmp(&to_size) {
+ Ordering::Greater => {
+ let divisor = from_size / to_size;
+ time_array.unary::<_, Int64Type>(|o| o / divisor)
+ }
+ Ordering::Equal => time_array.clone(),
+ Ordering::Less => {
+ let mul = to_size / from_size;
+ if cast_options.safe {
+ time_array.unary_opt::<_, Int64Type>(|o|
o.checked_mul(mul))
+ } else {
+ time_array.try_unary::<_, Int64Type, _>(|o|
o.mul_checked(mul))?
+ }
+ }
+ };
+ Ok(make_duration_array(&converted, *to_unit))
}
+
(Duration(TimeUnit::Second), Interval(IntervalUnit::MonthDayNano)) => {
cast_duration_to_interval::<DurationSecondType>(array,
cast_options)
}
@@ -5254,6 +5286,106 @@ mod tests {
}
}
+ #[test]
+ fn test_cast_between_durations_and_numerics() {
+ fn test_cast_between_durations<FromType, ToType>()
+ where
+ FromType: ArrowPrimitiveType<Native = i64>,
+ ToType: ArrowPrimitiveType<Native = i64>,
+ PrimitiveArray<FromType>: From<Vec<Option<i64>>>,
+ {
+ let from_unit = match FromType::DATA_TYPE {
+ DataType::Duration(unit) => unit,
+ _ => panic!("Expected a duration type"),
+ };
+ let to_unit = match ToType::DATA_TYPE {
+ DataType::Duration(unit) => unit,
+ _ => panic!("Expected a duration type"),
+ };
+ let from_size = time_unit_multiple(&from_unit);
+ let to_size = time_unit_multiple(&to_unit);
+
+ let (v1_before, v2_before) = (8640003005, 1696002001);
+ let (v1_after, v2_after) = if from_size >= to_size {
+ (
+ v1_before / (from_size / to_size),
+ v2_before / (from_size / to_size),
+ )
+ } else {
+ (
+ v1_before * (to_size / from_size),
+ v2_before * (to_size / from_size),
+ )
+ };
+
+ let array =
+ PrimitiveArray::<FromType>::from(vec![Some(v1_before),
Some(v2_before), None]);
+ let b = cast(&array, &ToType::DATA_TYPE).unwrap();
+ let c = b.as_primitive::<ToType>();
+ assert_eq!(v1_after, c.value(0));
+ assert_eq!(v2_after, c.value(1));
+ assert!(c.is_null(2));
+ }
+
+ // between each individual duration type
+ test_cast_between_durations::<DurationSecondType,
DurationMillisecondType>();
+ test_cast_between_durations::<DurationSecondType,
DurationMicrosecondType>();
+ test_cast_between_durations::<DurationSecondType,
DurationNanosecondType>();
+ test_cast_between_durations::<DurationMillisecondType,
DurationSecondType>();
+ test_cast_between_durations::<DurationMillisecondType,
DurationMicrosecondType>();
+ test_cast_between_durations::<DurationMillisecondType,
DurationNanosecondType>();
+ test_cast_between_durations::<DurationMicrosecondType,
DurationSecondType>();
+ test_cast_between_durations::<DurationMicrosecondType,
DurationMillisecondType>();
+ test_cast_between_durations::<DurationMicrosecondType,
DurationNanosecondType>();
+ test_cast_between_durations::<DurationNanosecondType,
DurationSecondType>();
+ test_cast_between_durations::<DurationNanosecondType,
DurationMillisecondType>();
+ test_cast_between_durations::<DurationNanosecondType,
DurationMicrosecondType>();
+
+ // unit conversion that overflows i64 yields null under the default (safe) cast
+ let array = DurationSecondArray::from(vec![
+ Some(i64::MAX),
+ Some(8640203410378005),
+ Some(10241096),
+ None,
+ ]);
+ let b = cast(&array,
&DataType::Duration(TimeUnit::Nanosecond)).unwrap();
+ let c = b.as_primitive::<DurationNanosecondType>();
+ assert!(c.is_null(0));
+ assert!(c.is_null(1));
+ assert_eq!(10241096000000000, c.value(2));
+ assert!(c.is_null(3));
+
+ // durations to numerics; values that do not fit the target type become null
+ let array = DurationSecondArray::from(vec![
+ Some(i64::MAX),
+ Some(8640203410378005),
+ Some(10241096),
+ None,
+ ]);
+ let b = cast(&array, &DataType::Int64).unwrap();
+ let c = b.as_primitive::<Int64Type>();
+ assert_eq!(i64::MAX, c.value(0));
+ assert_eq!(8640203410378005, c.value(1));
+ assert_eq!(10241096, c.value(2));
+ assert!(c.is_null(3));
+
+ let b = cast(&array, &DataType::Int32).unwrap();
+ let c = b.as_primitive::<Int32Type>();
+ assert!(c.is_null(0));
+ assert!(c.is_null(1));
+ assert_eq!(10241096, c.value(2));
+ assert!(c.is_null(3));
+
+ // numerics to durations
+ let array = Int32Array::from(vec![Some(i32::MAX), Some(802034103),
Some(10241096), None]);
+ let b = cast(&array, &DataType::Duration(TimeUnit::Second)).unwrap();
+ let c = b.as_any().downcast_ref::<DurationSecondArray>().unwrap();
+ assert_eq!(i32::MAX as i64, c.value(0));
+ assert_eq!(802034103, c.value(1));
+ assert_eq!(10241096, c.value(2));
+ assert!(c.is_null(3));
+ }
+
#[test]
fn test_cast_to_strings() {
let a = Int32Array::from(vec![1, 2, 3]);