This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 09d5f02a0 Enable casting between Dictionary of DecimalArray and DecimalArray (#3238)
09d5f02a0 is described below

commit 09d5f02a0bc76afbe7a764fc11f781f2cceaf4fb
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sat Dec 3 15:26:23 2022 -0800

    Enable casting between Dictionary of DecimalArray and DecimalArray (#3238)
    
    * Enable casting between Dictionary of DecimalArray and DecimalArray
    
    * Add tests and fix more issues
    
    * Move Dictionary matches to top
---
 arrow-cast/src/cast.rs    | 261 ++++++++++++++++++++++++----------------------
 arrow/tests/array_cast.rs |  86 +++++++++++----
 2 files changed, 206 insertions(+), 141 deletions(-)

diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index 8d28a6cc7..cddbf0d95 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -71,24 +71,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
     }
 
     match (from_type, to_type) {
-        // TODO UTF8 to decimal
-        // cast one decimal type to another decimal type
-        (Decimal128(_, _), Decimal128(_, _)) => true,
-        (Decimal256(_, _), Decimal256(_, _)) => true,
-        (Decimal128(_, _), Decimal256(_, _)) => true,
-        (Decimal256(_, _), Decimal128(_, _)) => true,
-        // unsigned integer to decimal
-        (UInt8 | UInt16 | UInt32 | UInt64, Decimal128(_, _)) |
-        // signed numeric to decimal
-        (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal128(_, _)) |
-        (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal256(_, _)) |
-        // decimal to unsigned numeric
-        (Decimal128(_, _), UInt8 | UInt16 | UInt32 | UInt64) |
-        (Decimal256(_, _), UInt8 | UInt16 | UInt32 | UInt64) |
-        // decimal to signed numeric
-        (Decimal128(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) |
-        (Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64)
-        | (
+        (
             Null,
             Boolean
             | Int8
@@ -120,10 +103,12 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
             | Map(_, _)
             | Dictionary(_, _)
         ) => true,
-        (Decimal128(_, _), _) => false,
-        (_, Decimal128(_, _)) => false,
-        (Struct(_), _) => false,
-        (_, Struct(_)) => false,
+        // Dictionary/List conditions should be put in front of others
+        (Dictionary(_, from_value_type), Dictionary(_, to_value_type)) => {
+            can_cast_types(from_value_type, to_value_type)
+        }
+        (Dictionary(_, value_type), _) => can_cast_types(value_type, to_type),
+        (_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type),
         (LargeList(list_from), LargeList(list_to)) => {
             can_cast_types(list_from.data_type(), list_to.data_type())
         }
@@ -140,12 +125,29 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
         (List(_), _) => false,
         (_, List(list_to)) => can_cast_types(from_type, list_to.data_type()),
         (_, LargeList(list_to)) => can_cast_types(from_type, list_to.data_type()),
-        (Dictionary(_, from_value_type), Dictionary(_, to_value_type)) => {
-            can_cast_types(from_value_type, to_value_type)
-        }
-        (Dictionary(_, value_type), _) => can_cast_types(value_type, to_type),
-        (_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type),
-
+        // TODO UTF8 to decimal
+        // cast one decimal type to another decimal type
+        (Decimal128(_, _), Decimal128(_, _)) => true,
+        (Decimal256(_, _), Decimal256(_, _)) => true,
+        (Decimal128(_, _), Decimal256(_, _)) => true,
+        (Decimal256(_, _), Decimal128(_, _)) => true,
+        // unsigned integer to decimal
+        (UInt8 | UInt16 | UInt32 | UInt64, Decimal128(_, _)) |
+        // signed numeric to decimal
+        (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal128(_, _)) |
+        (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal256(_, _)) |
+        // decimal to unsigned numeric
+        (Decimal128(_, _), UInt8 | UInt16 | UInt32 | UInt64) |
+        (Decimal256(_, _), UInt8 | UInt16 | UInt32 | UInt64) |
+        // decimal to signed numeric
+        (Decimal128(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) |
+        (Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64) => true,
+        (Decimal128(_, _), _) => false,
+        (_, Decimal128(_, _)) => false,
+        (Decimal256(_, _), _) => false,
+        (_, Decimal256(_, _)) => false,
+        (Struct(_), _) => false,
+        (_, Struct(_)) => false,
         (_, Boolean) => DataType::is_numeric(from_type) || from_type == &Utf8,
         (Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8,
 
@@ -624,6 +626,103 @@ pub fn cast_with_options(
         return Ok(array.clone());
     }
     match (from_type, to_type) {
+        (
+            Null,
+            Boolean
+            | Int8
+            | UInt8
+            | Int16
+            | UInt16
+            | Int32
+            | UInt32
+            | Float32
+            | Date32
+            | Time32(_)
+            | Int64
+            | UInt64
+            | Float64
+            | Date64
+            | Timestamp(_, _)
+            | Time64(_)
+            | Duration(_)
+            | Interval(_)
+            | FixedSizeBinary(_)
+            | Binary
+            | Utf8
+            | LargeBinary
+            | LargeUtf8
+            | List(_)
+            | LargeList(_)
+            | FixedSizeList(_, _)
+            | Struct(_)
+            | Map(_, _)
+            | Dictionary(_, _),
+        ) => Ok(new_null_array(to_type, array.len())),
+        (Dictionary(index_type, _), _) => match **index_type {
+            Int8 => dictionary_cast::<Int8Type>(array, to_type, cast_options),
+            Int16 => dictionary_cast::<Int16Type>(array, to_type, cast_options),
+            Int32 => dictionary_cast::<Int32Type>(array, to_type, cast_options),
+            Int64 => dictionary_cast::<Int64Type>(array, to_type, cast_options),
+            UInt8 => dictionary_cast::<UInt8Type>(array, to_type, cast_options),
+            UInt16 => dictionary_cast::<UInt16Type>(array, to_type, cast_options),
+            UInt32 => dictionary_cast::<UInt32Type>(array, to_type, cast_options),
+            UInt64 => dictionary_cast::<UInt64Type>(array, to_type, cast_options),
+            _ => Err(ArrowError::CastError(format!(
+                "Casting from dictionary type {:?} to {:?} not supported",
+                from_type, to_type,
+            ))),
+        },
+        (_, Dictionary(index_type, value_type)) => match **index_type {
+            Int8 => cast_to_dictionary::<Int8Type>(array, value_type, cast_options),
+            Int16 => cast_to_dictionary::<Int16Type>(array, value_type, cast_options),
+            Int32 => cast_to_dictionary::<Int32Type>(array, value_type, cast_options),
+            Int64 => cast_to_dictionary::<Int64Type>(array, value_type, cast_options),
+            UInt8 => cast_to_dictionary::<UInt8Type>(array, value_type, cast_options),
+            UInt16 => cast_to_dictionary::<UInt16Type>(array, value_type, cast_options),
+            UInt32 => cast_to_dictionary::<UInt32Type>(array, value_type, cast_options),
+            UInt64 => cast_to_dictionary::<UInt64Type>(array, value_type, cast_options),
+            _ => Err(ArrowError::CastError(format!(
+                "Casting from type {:?} to dictionary type {:?} not supported",
+                from_type, to_type,
+            ))),
+        },
+        (List(_), List(ref to)) => {
+            cast_list_inner::<i32>(array, to, to_type, cast_options)
+        }
+        (LargeList(_), LargeList(ref to)) => {
+            cast_list_inner::<i64>(array, to, to_type, cast_options)
+        }
+        (List(list_from), LargeList(list_to)) => {
+            if list_to.data_type() != list_from.data_type() {
+                Err(ArrowError::CastError(
+                    "cannot cast list to large-list with different child data".into(),
+                ))
+            } else {
+                cast_list_container::<i32, i64>(&**array, cast_options)
+            }
+        }
+        (LargeList(list_from), List(list_to)) => {
+            if list_to.data_type() != list_from.data_type() {
+                Err(ArrowError::CastError(
+                    "cannot cast large-list to list with different child data".into(),
+                ))
+            } else {
+                cast_list_container::<i64, i32>(&**array, cast_options)
+            }
+        }
+        (List(_) | LargeList(_), _) => match to_type {
+            Utf8 => cast_list_to_string!(array, i32),
+            LargeUtf8 => cast_list_to_string!(array, i64),
+            _ => Err(ArrowError::CastError(
+                "Cannot cast list to non-list data types".to_string(),
+            )),
+        },
+        (_, List(ref to)) => {
+            cast_primitive_to_list::<i32>(array, to, to_type, cast_options)
+        }
+        (_, LargeList(ref to)) => {
+            cast_primitive_to_list::<i64>(array, to, to_type, cast_options)
+        }
         (Decimal128(_, s1), Decimal128(p2, s2)) => {
             cast_decimal_to_decimal_with_option::<16, 16>(array, s1, p2, s2, cast_options)
         }
@@ -887,107 +986,12 @@ pub fn cast_with_options(
                 ))),
             }
         }
-        (
-            Null,
-            Boolean
-            | Int8
-            | UInt8
-            | Int16
-            | UInt16
-            | Int32
-            | UInt32
-            | Float32
-            | Date32
-            | Time32(_)
-            | Int64
-            | UInt64
-            | Float64
-            | Date64
-            | Timestamp(_, _)
-            | Time64(_)
-            | Duration(_)
-            | Interval(_)
-            | FixedSizeBinary(_)
-            | Binary
-            | Utf8
-            | LargeBinary
-            | LargeUtf8
-            | List(_)
-            | LargeList(_)
-            | FixedSizeList(_, _)
-            | Struct(_)
-            | Map(_, _)
-            | Dictionary(_, _),
-        ) => Ok(new_null_array(to_type, array.len())),
         (Struct(_), _) => Err(ArrowError::CastError(
             "Cannot cast from struct to other types".to_string(),
         )),
         (_, Struct(_)) => Err(ArrowError::CastError(
             "Cannot cast to struct from other types".to_string(),
         )),
-        (List(_), List(ref to)) => {
-            cast_list_inner::<i32>(array, to, to_type, cast_options)
-        }
-        (LargeList(_), LargeList(ref to)) => {
-            cast_list_inner::<i64>(array, to, to_type, cast_options)
-        }
-        (List(list_from), LargeList(list_to)) => {
-            if list_to.data_type() != list_from.data_type() {
-                Err(ArrowError::CastError(
-                    "cannot cast list to large-list with different child data".into(),
-                ))
-            } else {
-                cast_list_container::<i32, i64>(&**array, cast_options)
-            }
-        }
-        (LargeList(list_from), List(list_to)) => {
-            if list_to.data_type() != list_from.data_type() {
-                Err(ArrowError::CastError(
-                    "cannot cast large-list to list with different child data".into(),
-                ))
-            } else {
-                cast_list_container::<i64, i32>(&**array, cast_options)
-            }
-        }
-        (List(_) | LargeList(_), Utf8) => cast_list_to_string!(array, i32),
-        (List(_) | LargeList(_), LargeUtf8) => cast_list_to_string!(array, i64),
-        (List(_), _) => Err(ArrowError::CastError(
-            "Cannot cast list to non-list data types".to_string(),
-        )),
-        (_, List(ref to)) => {
-            cast_primitive_to_list::<i32>(array, to, to_type, cast_options)
-        }
-        (_, LargeList(ref to)) => {
-            cast_primitive_to_list::<i64>(array, to, to_type, cast_options)
-        }
-        (Dictionary(index_type, _), _) => match **index_type {
-            Int8 => dictionary_cast::<Int8Type>(array, to_type, cast_options),
-            Int16 => dictionary_cast::<Int16Type>(array, to_type, cast_options),
-            Int32 => dictionary_cast::<Int32Type>(array, to_type, cast_options),
-            Int64 => dictionary_cast::<Int64Type>(array, to_type, cast_options),
-            UInt8 => dictionary_cast::<UInt8Type>(array, to_type, cast_options),
-            UInt16 => dictionary_cast::<UInt16Type>(array, to_type, cast_options),
-            UInt32 => dictionary_cast::<UInt32Type>(array, to_type, cast_options),
-            UInt64 => dictionary_cast::<UInt64Type>(array, to_type, cast_options),
-            _ => Err(ArrowError::CastError(format!(
-                "Casting from dictionary type {:?} to {:?} not supported",
-                from_type, to_type,
-            ))),
-        },
-        (_, Dictionary(index_type, value_type)) => match **index_type {
-            Int8 => cast_to_dictionary::<Int8Type>(array, value_type, cast_options),
-            Int16 => cast_to_dictionary::<Int16Type>(array, value_type, cast_options),
-            Int32 => cast_to_dictionary::<Int32Type>(array, value_type, cast_options),
-            Int64 => cast_to_dictionary::<Int64Type>(array, value_type, cast_options),
-            UInt8 => cast_to_dictionary::<UInt8Type>(array, value_type, cast_options),
-            UInt16 => cast_to_dictionary::<UInt16Type>(array, value_type, cast_options),
-            UInt32 => cast_to_dictionary::<UInt32Type>(array, value_type, cast_options),
-            UInt64 => cast_to_dictionary::<UInt64Type>(array, value_type, cast_options),
-            _ => Err(ArrowError::CastError(format!(
-                "Casting from type {:?} to dictionary type {:?} not supported",
-                from_type, to_type,
-            ))),
-        },
         (_, Boolean) => match from_type {
             UInt8 => cast_numeric_to_bool::<UInt8Type>(array),
             UInt16 => cast_numeric_to_bool::<UInt16Type>(array),
@@ -3390,7 +3394,18 @@ fn cast_to_dictionary<K: ArrowDictionaryKeyType>(
             dict_value_type,
             cast_options,
         ),
+        Decimal128(_, _) => pack_numeric_to_dictionary::<K, Decimal128Type>(
+            array,
+            dict_value_type,
+            cast_options,
+        ),
+        Decimal256(_, _) => pack_numeric_to_dictionary::<K, Decimal256Type>(
+            array,
+            dict_value_type,
+            cast_options,
+        ),
         Utf8 => pack_string_to_dictionary::<K>(array, cast_options),
+        LargeUtf8 => pack_string_to_dictionary::<K>(array, cast_options),
         _ => Err(ArrowError::CastError(format!(
             "Unsupported output type for dictionary packing: {:?}",
             dict_value_type
diff --git a/arrow/tests/array_cast.rs b/arrow/tests/array_cast.rs
index 95fb97328..be37a7636 100644
--- a/arrow/tests/array_cast.rs
+++ b/arrow/tests/array_cast.rs
@@ -19,12 +19,13 @@ use arrow_array::builder::{
     PrimitiveDictionaryBuilder, StringDictionaryBuilder, UnionBuilder,
 };
 use arrow_array::types::{
-    ArrowDictionaryKeyType, Int16Type, Int32Type, Int64Type, Int8Type,
-    TimestampMicrosecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+    ArrowDictionaryKeyType, Decimal128Type, Decimal256Type, Int16Type, Int32Type,
+    Int64Type, Int8Type, TimestampMicrosecondType, UInt16Type, UInt32Type, UInt64Type,
+    UInt8Type,
 };
 use arrow_array::{
-    Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array,
-    Decimal128Array, DurationMicrosecondArray, DurationMillisecondArray,
+    Array, ArrayRef, ArrowPrimitiveType, BinaryArray, BooleanArray, Date32Array,
+    Date64Array, Decimal128Array, DurationMicrosecondArray, DurationMillisecondArray,
     DurationNanosecondArray, DurationSecondArray, FixedSizeBinaryArray,
     FixedSizeListArray, Float16Array, Float32Array, Float64Array, Int16Array, Int32Array,
     Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray,
@@ -35,7 +36,7 @@ use arrow_array::{
     TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
     UInt64Array, UInt8Array, UnionArray,
 };
-use arrow_buffer::Buffer;
+use arrow_buffer::{i256, Buffer};
 use arrow_cast::{can_cast_types, cast};
 use arrow_data::ArrayData;
 use arrow_schema::{ArrowError, DataType, Field, IntervalUnit, TimeUnit, UnionMode};
@@ -101,14 +102,14 @@ fn get_arrays_of_all_types() -> Vec<ArrayRef> {
     vec![
         Arc::new(BinaryArray::from(binary_data.clone())),
         Arc::new(LargeBinaryArray::from(binary_data.clone())),
-        make_dictionary_primitive::<Int8Type>(),
-        make_dictionary_primitive::<Int16Type>(),
-        make_dictionary_primitive::<Int32Type>(),
-        make_dictionary_primitive::<Int64Type>(),
-        make_dictionary_primitive::<UInt8Type>(),
-        make_dictionary_primitive::<UInt16Type>(),
-        make_dictionary_primitive::<UInt32Type>(),
-        make_dictionary_primitive::<UInt64Type>(),
+        make_dictionary_primitive::<Int8Type, Int32Type>(vec![1, 2]),
+        make_dictionary_primitive::<Int16Type, Int32Type>(vec![1, 2]),
+        make_dictionary_primitive::<Int32Type, Int32Type>(vec![1, 2]),
+        make_dictionary_primitive::<Int64Type, Int32Type>(vec![1, 2]),
+        make_dictionary_primitive::<UInt8Type, Int32Type>(vec![1, 2]),
+        make_dictionary_primitive::<UInt16Type, Int32Type>(vec![1, 2]),
+        make_dictionary_primitive::<UInt32Type, Int32Type>(vec![1, 2]),
+        make_dictionary_primitive::<UInt64Type, Int32Type>(vec![1, 2]),
         make_dictionary_utf8::<Int8Type>(),
         make_dictionary_utf8::<Int16Type>(),
         make_dictionary_utf8::<Int32Type>(),
@@ -184,6 +185,46 @@ fn get_arrays_of_all_types() -> Vec<ArrayRef> {
         Arc::new(
             create_decimal_array(vec![Some(1), Some(2), Some(3), None], 38, 0).unwrap(),
         ),
+        make_dictionary_primitive::<Int8Type, Decimal128Type>(vec![1, 2]),
+        make_dictionary_primitive::<Int16Type, Decimal128Type>(vec![1, 2]),
+        make_dictionary_primitive::<Int32Type, Decimal128Type>(vec![1, 2]),
+        make_dictionary_primitive::<Int64Type, Decimal128Type>(vec![1, 2]),
+        make_dictionary_primitive::<UInt8Type, Decimal128Type>(vec![1, 2]),
+        make_dictionary_primitive::<UInt16Type, Decimal128Type>(vec![1, 2]),
+        make_dictionary_primitive::<UInt32Type, Decimal128Type>(vec![1, 2]),
+        make_dictionary_primitive::<UInt64Type, Decimal128Type>(vec![1, 2]),
+        make_dictionary_primitive::<Int8Type, Decimal256Type>(vec![
+            i256::from_i128(1),
+            i256::from_i128(2),
+        ]),
+        make_dictionary_primitive::<Int16Type, Decimal256Type>(vec![
+            i256::from_i128(1),
+            i256::from_i128(2),
+        ]),
+        make_dictionary_primitive::<Int32Type, Decimal256Type>(vec![
+            i256::from_i128(1),
+            i256::from_i128(2),
+        ]),
+        make_dictionary_primitive::<Int64Type, Decimal256Type>(vec![
+            i256::from_i128(1),
+            i256::from_i128(2),
+        ]),
+        make_dictionary_primitive::<UInt8Type, Decimal256Type>(vec![
+            i256::from_i128(1),
+            i256::from_i128(2),
+        ]),
+        make_dictionary_primitive::<UInt16Type, Decimal256Type>(vec![
+            i256::from_i128(1),
+            i256::from_i128(2),
+        ]),
+        make_dictionary_primitive::<UInt32Type, Decimal256Type>(vec![
+            i256::from_i128(1),
+            i256::from_i128(2),
+        ]),
+        make_dictionary_primitive::<UInt64Type, Decimal256Type>(vec![
+            i256::from_i128(1),
+            i256::from_i128(2),
+        ]),
     ]
 }
 
@@ -273,12 +314,15 @@ fn make_union_array() -> UnionArray {
 }
 
 /// Creates a dictionary with primitive dictionary values, and keys of type K
-fn make_dictionary_primitive<K: ArrowDictionaryKeyType>() -> ArrayRef {
+/// and values of type V
+fn make_dictionary_primitive<K: ArrowDictionaryKeyType, V: ArrowPrimitiveType>(
+    values: Vec<V::Native>,
+) -> ArrayRef {
     // Pick Int32 arbitrarily for dictionary values
-    let mut b: PrimitiveDictionaryBuilder<K, Int32Type> =
-        PrimitiveDictionaryBuilder::new();
-    b.append(1).unwrap();
-    b.append(2).unwrap();
+    let mut b: PrimitiveDictionaryBuilder<K, V> = PrimitiveDictionaryBuilder::new();
+    values.iter().for_each(|v| {
+        b.append(*v).unwrap();
+    });
     Arc::new(b.finish())
 }
 
@@ -369,6 +413,12 @@ fn get_all_types() -> Vec<DataType> {
         Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)),
         Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
         Decimal128(38, 0),
+        Dictionary(Box::new(DataType::Int8), Box::new(Decimal128(38, 0))),
+        Dictionary(Box::new(DataType::Int16), Box::new(Decimal128(38, 0))),
+        Dictionary(Box::new(DataType::UInt32), Box::new(Decimal128(38, 0))),
+        Dictionary(Box::new(DataType::Int8), Box::new(Decimal256(76, 0))),
+        Dictionary(Box::new(DataType::Int16), Box::new(Decimal256(76, 0))),
+        Dictionary(Box::new(DataType::UInt32), Box::new(Decimal256(76, 0))),
     ]
 }
 

Reply via email to