This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 9f0fe6b7e refactor: simplify can_cast_types code. (#4185)
9f0fe6b7e is described below
commit 9f0fe6b7ecb3c8bd7591f2ccd78dd92b50563988
Author: jakevin <[email protected]>
AuthorDate: Wed May 10 15:54:24 2023 +0800
refactor: simplify can_cast_types code. (#4185)
---
arrow-cast/src/cast.rs | 184 +++++++++++++++++--------------------------------
1 file changed, 62 insertions(+), 122 deletions(-)
diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index d015f4952..37fede0a6 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -78,6 +78,8 @@ impl<'a> Default for CastOptions<'a> {
/// If this function returns true to stay consistent with the `cast` kernel
below.
pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
use self::DataType::*;
+ use self::IntervalUnit::*;
+ use self::TimeUnit::*;
if from_type == to_type {
return true;
}
@@ -113,7 +115,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
| FixedSizeList(_, _)
| Struct(_)
| Map(_, _)
- | Dictionary(_, _)
+ | Dictionary(_, _),
) => true,
// Dictionary/List conditions should be put in front of others
(Dictionary(_, from_value_type), Dictionary(_, to_value_type)) => {
@@ -133,7 +135,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(LargeList(list_from), List(list_to)) => {
list_from.data_type() == list_to.data_type()
}
- (List(list_from) | LargeList(list_from), Utf8 | LargeUtf8) =>
can_cast_types(list_from.data_type(), to_type),
+ (List(list_from) | LargeList(list_from), Utf8 | LargeUtf8) => {
+ can_cast_types(list_from.data_type(), to_type)
+ }
(List(_), _) => false,
(_, List(list_to)) => can_cast_types(from_type, list_to.data_type()),
(_, LargeList(list_to)) => can_cast_types(from_type,
list_to.data_type()),
@@ -149,114 +153,54 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64,
Decimal128(_, _)) |
(Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64,
Decimal256(_, _)) |
// decimal to unsigned numeric
- (Decimal128(_, _), UInt8 | UInt16 | UInt32 | UInt64) |
- (Decimal256(_, _), UInt8 | UInt16 | UInt32 | UInt64) |
+ (Decimal128(_, _) | Decimal256(_, _), UInt8 | UInt16 | UInt32 |
UInt64) |
// decimal to signed numeric
- (Decimal128(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 |
Float64) |
- (Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 |
Float64) => true,
+ (Decimal128(_, _) | Decimal256(_, _), Null | Int8 | Int16 | Int32 |
Int64 | Float32 | Float64) => true,
// decimal to Utf8
- (Decimal128(_, _), Utf8 | LargeUtf8) => true,
- (Decimal256(_, _), Utf8 | LargeUtf8) => true,
+ (Decimal128(_, _) | Decimal256(_, _), Utf8 | LargeUtf8) => true,
// Utf8 to decimal
- (Utf8 | LargeUtf8, Decimal128(_, _)) => true,
- (Utf8 | LargeUtf8, Decimal256(_, _)) => true,
- (Decimal128(_, _), _) => false,
- (_, Decimal128(_, _)) => false,
- (Decimal256(_, _), _) => false,
- (_, Decimal256(_, _)) => false,
+ (Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true,
+ (Decimal128(_, _) | Decimal256(_, _), _) => false,
+ (_, Decimal128(_, _) | Decimal256(_, _)) => false,
(Struct(_), _) => false,
(_, Struct(_)) => false,
- (_, Boolean) => DataType::is_numeric(from_type) || from_type == &Utf8
|| from_type == &LargeUtf8,
- (Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8 ||
to_type == &LargeUtf8,
+ (_, Boolean) => {
+ DataType::is_numeric(from_type)
+ || from_type == &Utf8
+ || from_type == &LargeUtf8
+ }
+ (Boolean, _) => {
+ DataType::is_numeric(to_type) || to_type == &Utf8 || to_type ==
&LargeUtf8
+ }
(Binary, LargeBinary | Utf8 | LargeUtf8 | FixedSizeBinary(_)) => true,
(LargeBinary, Binary | Utf8 | LargeUtf8 | FixedSizeBinary(_)) => true,
(FixedSizeBinary(_), Binary | LargeBinary) => true,
- (Utf8,
- Binary
- | LargeBinary
- | LargeUtf8
- | Date32
- | Date64
- | Time32(TimeUnit::Second)
- | Time32(TimeUnit::Millisecond)
- | Time64(TimeUnit::Microsecond)
- | Time64(TimeUnit::Nanosecond)
- | Timestamp(TimeUnit::Second, _)
- | Timestamp(TimeUnit::Millisecond, _)
- | Timestamp(TimeUnit::Microsecond, _)
- | Timestamp(TimeUnit::Nanosecond, _)
- | Interval(_)
- ) => true,
- (Utf8, _) => to_type.is_numeric() && to_type != &Float16,
- (LargeUtf8,
+ (
+ Utf8 | LargeUtf8,
Binary
| LargeBinary
| Utf8
+ | LargeUtf8
| Date32
| Date64
- | Time32(TimeUnit::Second)
- | Time32(TimeUnit::Millisecond)
- | Time64(TimeUnit::Microsecond)
- | Time64(TimeUnit::Nanosecond)
- | Timestamp(TimeUnit::Second, _)
- | Timestamp(TimeUnit::Millisecond, _)
- | Timestamp(TimeUnit::Microsecond, _)
- | Timestamp(TimeUnit::Nanosecond, _)
- | Interval(_)
+ | Time32(Second)
+ | Time32(Millisecond)
+ | Time64(Microsecond)
+ | Time64(Nanosecond)
+ | Timestamp(Second, _)
+ | Timestamp(Millisecond, _)
+ | Timestamp(Microsecond, _)
+ | Timestamp(Nanosecond, _)
+ | Interval(_),
) => true,
- (LargeUtf8, _) => to_type.is_numeric() && to_type != &Float16,
+ (Utf8 | LargeUtf8, _) => to_type.is_numeric() && to_type != &Float16,
(_, Utf8 | LargeUtf8) => from_type.is_primitive(),
// start numeric casts
(
- UInt8,
- UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32
| Float64,
- ) => true,
-
- (
- UInt16,
- UInt8 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 |
Float64,
- ) => true,
-
- (
- UInt32,
- UInt8 | UInt16 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 |
Float64,
- ) => true,
-
- (
- UInt64,
- UInt8 | UInt16 | UInt32 | Int8 | Int16 | Int32 | Int64 | Float32 |
Float64,
- ) => true,
-
- (
- Int8,
- UInt8 | UInt16 | UInt32 | UInt64 | Int16 | Int32 | Int64 | Float32
| Float64,
- ) => true,
-
- (
- Int16,
- UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int32 | Int64 | Float32
| Float64,
- ) => true,
-
- (
- Int32,
- UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int64 | Float32
| Float64,
- ) => true,
-
- (
- Int64,
- UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Float32
| Float64,
- ) => true,
-
- (
- Float32,
- UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 |
Float64,
- ) => true,
-
- (
- Float64,
- UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 |
Float32,
+ UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 |
Float32 | Float64,
+ UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 |
Float32 | Float64,
) => true,
// end numeric casts
@@ -267,53 +211,49 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Int64, Date64 | Date32 | Time64(_)) => true,
(Date64, Int64 | Int32) => true,
(Time64(_), Int64) => true,
- (Date32, Date64) => true,
- (Date64, Date32) => true,
- (Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond)) => true,
- (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Second)) => true,
+ (Date32 | Date64, Date32 | Date64) => true,
+ // time casts
+ (Time32(_), Time32(_)) => true,
(Time32(_), Time64(_)) => true,
- (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Nanosecond)) => true,
- (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Microsecond)) => true,
+ (Time64(_), Time64(_)) => true,
(Time64(_), Time32(to_unit)) => {
- matches!(to_unit, TimeUnit::Second | TimeUnit::Millisecond)
+ matches!(to_unit, Second | Millisecond)
}
(Timestamp(_, _), Int64) => true,
(Int64, Timestamp(_, _)) => true,
(Date64, Timestamp(_, None)) => true,
(Date32, Timestamp(_, None)) => true,
- (Timestamp(_, _),
+ (
+ Timestamp(_, _),
Timestamp(_, _)
| Date32
| Date64
- | Time32(TimeUnit::Second)
- | Time32(TimeUnit::Millisecond)
- | Time64(TimeUnit::Microsecond)
- | Time64(TimeUnit::Nanosecond)) => true,
+ | Time32(Second)
+ | Time32(Millisecond)
+ | Time64(Microsecond)
+ | Time64(Nanosecond),
+ ) => true,
(Int64, Duration(_)) => true,
(Duration(_), Int64) => true,
(Interval(from_type), Int64) => {
match from_type {
- IntervalUnit::YearMonth => true,
- IntervalUnit::DayTime => true,
- IntervalUnit::MonthDayNano => false, // Native type is i128
- }
- }
- (Int32, Interval(to_type)) => {
- match to_type {
- IntervalUnit::YearMonth => true,
- IntervalUnit::DayTime => false,
- IntervalUnit::MonthDayNano => false,
- }
- }
- (Int64, Interval(to_type)) => {
- match to_type {
- IntervalUnit::YearMonth => false,
- IntervalUnit::DayTime => true,
- IntervalUnit::MonthDayNano => false,
+ YearMonth => true,
+ DayTime => true,
+ MonthDayNano => false, // Native type is i128
}
}
- (Duration(_), Interval(IntervalUnit::MonthDayNano)) => true,
- (Interval(IntervalUnit::MonthDayNano), Duration(_)) => true,
+ (Int32, Interval(to_type)) => match to_type {
+ YearMonth => true,
+ DayTime => false,
+ MonthDayNano => false,
+ },
+ (Int64, Interval(to_type)) => match to_type {
+ YearMonth => false,
+ DayTime => true,
+ MonthDayNano => false,
+ },
+ (Duration(_), Interval(MonthDayNano)) => true,
+ (Interval(MonthDayNano), Duration(_)) => true,
(_, _) => false,
}
}