This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 9f0fe6b7e refactor: simplify can_cast_types code. (#4185)
9f0fe6b7e is described below

commit 9f0fe6b7ecb3c8bd7591f2ccd78dd92b50563988
Author: jakevin <[email protected]>
AuthorDate: Wed May 10 15:54:24 2023 +0800

    refactor: simplify can_cast_types code. (#4185)
---
 arrow-cast/src/cast.rs | 184 +++++++++++++++++--------------------------------
 1 file changed, 62 insertions(+), 122 deletions(-)

diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index d015f4952..37fede0a6 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -78,6 +78,8 @@ impl<'a> Default for CastOptions<'a> {
 /// If this function returns true to stay consistent with the `cast` kernel below.
 pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
     use self::DataType::*;
+    use self::IntervalUnit::*;
+    use self::TimeUnit::*;
     if from_type == to_type {
         return true;
     }
@@ -113,7 +115,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
             | FixedSizeList(_, _)
             | Struct(_)
             | Map(_, _)
-            | Dictionary(_, _)
+            | Dictionary(_, _),
         ) => true,
         // Dictionary/List conditions should be put in front of others
         (Dictionary(_, from_value_type), Dictionary(_, to_value_type)) => {
@@ -133,7 +135,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
         (LargeList(list_from), List(list_to)) => {
             list_from.data_type() == list_to.data_type()
         }
-        (List(list_from) | LargeList(list_from), Utf8 | LargeUtf8) => can_cast_types(list_from.data_type(), to_type),
+        (List(list_from) | LargeList(list_from), Utf8 | LargeUtf8) => {
+            can_cast_types(list_from.data_type(), to_type)
+        }
         (List(_), _) => false,
         (_, List(list_to)) => can_cast_types(from_type, list_to.data_type()),
         (_, LargeList(list_to)) => can_cast_types(from_type, list_to.data_type()),
@@ -149,114 +153,54 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
         (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal128(_, _)) |
         (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal256(_, _)) |
         // decimal to unsigned numeric
-        (Decimal128(_, _), UInt8 | UInt16 | UInt32 | UInt64) |
-        (Decimal256(_, _), UInt8 | UInt16 | UInt32 | UInt64) |
+        (Decimal128(_, _) | Decimal256(_, _), UInt8 | UInt16 | UInt32 | UInt64) |
         // decimal to signed numeric
-        (Decimal128(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) |
-        (Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) => true,
+        (Decimal128(_, _) | Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) => true,
         // decimal to Utf8
-        (Decimal128(_, _), Utf8 | LargeUtf8) => true,
-        (Decimal256(_, _), Utf8 | LargeUtf8) => true,
+        (Decimal128(_, _) | Decimal256(_, _), Utf8 | LargeUtf8) => true,
         // Utf8 to decimal
-        (Utf8 | LargeUtf8, Decimal128(_, _)) => true,
-        (Utf8 | LargeUtf8, Decimal256(_, _)) => true,
-        (Decimal128(_, _), _) => false,
-        (_, Decimal128(_, _)) => false,
-        (Decimal256(_, _), _) => false,
-        (_, Decimal256(_, _)) => false,
+        (Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true,
+        (Decimal128(_, _) | Decimal256(_, _), _) => false,
+        (_, Decimal128(_, _) | Decimal256(_, _)) => false,
         (Struct(_), _) => false,
         (_, Struct(_)) => false,
-        (_, Boolean) => DataType::is_numeric(from_type) || from_type == &Utf8 || from_type == &LargeUtf8,
-        (Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8 || to_type == &LargeUtf8,
+        (_, Boolean) => {
+            DataType::is_numeric(from_type)
+                || from_type == &Utf8
+                || from_type == &LargeUtf8
+        }
+        (Boolean, _) => {
+            DataType::is_numeric(to_type) || to_type == &Utf8 || to_type == &LargeUtf8
+        }
 
         (Binary, LargeBinary | Utf8 | LargeUtf8 | FixedSizeBinary(_)) => true,
         (LargeBinary, Binary | Utf8 | LargeUtf8 | FixedSizeBinary(_)) => true,
         (FixedSizeBinary(_), Binary | LargeBinary) => true,
-        (Utf8,
-            Binary
-            | LargeBinary
-            | LargeUtf8
-            | Date32
-            | Date64
-            | Time32(TimeUnit::Second)
-            | Time32(TimeUnit::Millisecond)
-            | Time64(TimeUnit::Microsecond)
-            | Time64(TimeUnit::Nanosecond)
-            | Timestamp(TimeUnit::Second, _)
-            | Timestamp(TimeUnit::Millisecond, _)
-            | Timestamp(TimeUnit::Microsecond, _)
-            | Timestamp(TimeUnit::Nanosecond, _)
-            | Interval(_)
-        ) => true,
-        (Utf8, _) => to_type.is_numeric() && to_type != &Float16,
-        (LargeUtf8,
+        (
+            Utf8 | LargeUtf8,
             Binary
             | LargeBinary
             | Utf8
+            | LargeUtf8
             | Date32
             | Date64
-            | Time32(TimeUnit::Second)
-            | Time32(TimeUnit::Millisecond)
-            | Time64(TimeUnit::Microsecond)
-            | Time64(TimeUnit::Nanosecond)
-            | Timestamp(TimeUnit::Second, _)
-            | Timestamp(TimeUnit::Millisecond, _)
-            | Timestamp(TimeUnit::Microsecond, _)
-            | Timestamp(TimeUnit::Nanosecond, _)
-            | Interval(_)
+            | Time32(Second)
+            | Time32(Millisecond)
+            | Time64(Microsecond)
+            | Time64(Nanosecond)
+            | Timestamp(Second, _)
+            | Timestamp(Millisecond, _)
+            | Timestamp(Microsecond, _)
+            | Timestamp(Nanosecond, _)
+            | Interval(_),
         ) => true,
-        (LargeUtf8, _) => to_type.is_numeric() && to_type != &Float16,
+        (Utf8 | LargeUtf8, _) => to_type.is_numeric() && to_type != &Float16,
         (_, Utf8 | LargeUtf8) => from_type.is_primitive(),
 
         // start numeric casts
         (
-            UInt8,
-            UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64,
-        ) => true,
-
-        (
-            UInt16,
-            UInt8 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64,
-        ) => true,
-
-        (
-            UInt32,
-            UInt8 | UInt16 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64,
-        ) => true,
-
-        (
-            UInt64,
-            UInt8 | UInt16 | UInt32 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64,
-        ) => true,
-
-        (
-            Int8,
-            UInt8 | UInt16 | UInt32 | UInt64 | Int16 | Int32 | Int64 | Float32 | Float64,
-        ) => true,
-
-        (
-            Int16,
-            UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int32 | Int64 | Float32 | Float64,
-        ) => true,
-
-        (
-            Int32,
-            UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int64 | Float32 | Float64,
-        ) => true,
-
-        (
-            Int64,
-            UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Float32 | Float64,
-        ) => true,
-
-        (
-            Float32,
-            UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float64,
-        ) => true,
-
-        (
-            Float64,
-            UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32,
+            UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64,
+            UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64,
         ) => true,
         // end numeric casts
 
@@ -267,53 +211,49 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
         (Int64, Date64 | Date32 | Time64(_)) => true,
         (Date64, Int64 | Int32) => true,
         (Time64(_), Int64) => true,
-        (Date32, Date64) => true,
-        (Date64, Date32) => true,
-        (Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond)) => true,
-        (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Second)) => true,
+        (Date32 | Date64, Date32 | Date64) => true,
+        // time casts
+        (Time32(_), Time32(_)) => true,
         (Time32(_), Time64(_)) => true,
-        (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Nanosecond)) => true,
-        (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Microsecond)) => true,
+        (Time64(_), Time64(_)) => true,
         (Time64(_), Time32(to_unit)) => {
-            matches!(to_unit, TimeUnit::Second | TimeUnit::Millisecond)
+            matches!(to_unit, Second | Millisecond)
         }
         (Timestamp(_, _), Int64) => true,
         (Int64, Timestamp(_, _)) => true,
         (Date64, Timestamp(_, None)) => true,
         (Date32, Timestamp(_, None)) => true,
-        (Timestamp(_, _),
+        (
+            Timestamp(_, _),
             Timestamp(_, _)
             | Date32
             | Date64
-            | Time32(TimeUnit::Second)
-            | Time32(TimeUnit::Millisecond)
-            | Time64(TimeUnit::Microsecond)
-            | Time64(TimeUnit::Nanosecond)) => true,
+            | Time32(Second)
+            | Time32(Millisecond)
+            | Time64(Microsecond)
+            | Time64(Nanosecond),
+        ) => true,
         (Int64, Duration(_)) => true,
         (Duration(_), Int64) => true,
         (Interval(from_type), Int64) => {
             match from_type {
-                IntervalUnit::YearMonth => true,
-                IntervalUnit::DayTime => true,
-                IntervalUnit::MonthDayNano => false, // Native type is i128
-            }
-        }
-        (Int32, Interval(to_type)) => {
-            match to_type {
-                IntervalUnit::YearMonth => true,
-                IntervalUnit::DayTime => false,
-                IntervalUnit::MonthDayNano => false,
-            }
-        }
-        (Int64, Interval(to_type)) => {
-            match to_type {
-                IntervalUnit::YearMonth => false,
-                IntervalUnit::DayTime => true,
-                IntervalUnit::MonthDayNano => false,
+                YearMonth => true,
+                DayTime => true,
+                MonthDayNano => false, // Native type is i128
             }
         }
-        (Duration(_), Interval(IntervalUnit::MonthDayNano)) => true,
-        (Interval(IntervalUnit::MonthDayNano), Duration(_)) => true,
+        (Int32, Interval(to_type)) => match to_type {
+            YearMonth => true,
+            DayTime => false,
+            MonthDayNano => false,
+        },
+        (Int64, Interval(to_type)) => match to_type {
+            YearMonth => false,
+            DayTime => true,
+            MonthDayNano => false,
+        },
+        (Duration(_), Interval(MonthDayNano)) => true,
+        (Interval(MonthDayNano), Duration(_)) => true,
         (_, _) => false,
     }
 }

Reply via email to