This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 09d5f02a0 Enable casting between Dictionary of DecimalArray and DecimalArray (#3238)
09d5f02a0 is described below
commit 09d5f02a0bc76afbe7a764fc11f781f2cceaf4fb
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sat Dec 3 15:26:23 2022 -0800
Enable casting between Dictionary of DecimalArray and DecimalArray (#3238)
* Enable casting between Dictionary of DecimalArray and DecimalArray
* Add tests and fix more issues
* Move Dictionary matches to top
---
arrow-cast/src/cast.rs | 261 ++++++++++++++++++++++++----------------------
arrow/tests/array_cast.rs | 86 +++++++++++----
2 files changed, 206 insertions(+), 141 deletions(-)
diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index 8d28a6cc7..cddbf0d95 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -71,24 +71,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
}
match (from_type, to_type) {
- // TODO UTF8 to decimal
- // cast one decimal type to another decimal type
- (Decimal128(_, _), Decimal128(_, _)) => true,
- (Decimal256(_, _), Decimal256(_, _)) => true,
- (Decimal128(_, _), Decimal256(_, _)) => true,
- (Decimal256(_, _), Decimal128(_, _)) => true,
- // unsigned integer to decimal
- (UInt8 | UInt16 | UInt32 | UInt64, Decimal128(_, _)) |
- // signed numeric to decimal
- (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64,
Decimal128(_, _)) |
- (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64,
Decimal256(_, _)) |
- // decimal to unsigned numeric
- (Decimal128(_, _), UInt8 | UInt16 | UInt32 | UInt64) |
- (Decimal256(_, _), UInt8 | UInt16 | UInt32 | UInt64) |
- // decimal to signed numeric
- (Decimal128(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 |
Float64) |
- (Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64)
- | (
+ (
Null,
Boolean
| Int8
@@ -120,10 +103,12 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
| Map(_, _)
| Dictionary(_, _)
) => true,
- (Decimal128(_, _), _) => false,
- (_, Decimal128(_, _)) => false,
- (Struct(_), _) => false,
- (_, Struct(_)) => false,
+ // Dictionary/List conditions should be put in front of others
+ (Dictionary(_, from_value_type), Dictionary(_, to_value_type)) => {
+ can_cast_types(from_value_type, to_value_type)
+ }
+ (Dictionary(_, value_type), _) => can_cast_types(value_type, to_type),
+ (_, Dictionary(_, value_type)) => can_cast_types(from_type,
value_type),
(LargeList(list_from), LargeList(list_to)) => {
can_cast_types(list_from.data_type(), list_to.data_type())
}
@@ -140,12 +125,29 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(List(_), _) => false,
(_, List(list_to)) => can_cast_types(from_type, list_to.data_type()),
(_, LargeList(list_to)) => can_cast_types(from_type,
list_to.data_type()),
- (Dictionary(_, from_value_type), Dictionary(_, to_value_type)) => {
- can_cast_types(from_value_type, to_value_type)
- }
- (Dictionary(_, value_type), _) => can_cast_types(value_type, to_type),
- (_, Dictionary(_, value_type)) => can_cast_types(from_type,
value_type),
-
+ // TODO UTF8 to decimal
+ // cast one decimal type to another decimal type
+ (Decimal128(_, _), Decimal128(_, _)) => true,
+ (Decimal256(_, _), Decimal256(_, _)) => true,
+ (Decimal128(_, _), Decimal256(_, _)) => true,
+ (Decimal256(_, _), Decimal128(_, _)) => true,
+ // unsigned integer to decimal
+ (UInt8 | UInt16 | UInt32 | UInt64, Decimal128(_, _)) |
+ // signed numeric to decimal
+ (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64,
Decimal128(_, _)) |
+ (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64,
Decimal256(_, _)) |
+ // decimal to unsigned numeric
+ (Decimal128(_, _), UInt8 | UInt16 | UInt32 | UInt64) |
+ (Decimal256(_, _), UInt8 | UInt16 | UInt32 | UInt64) |
+ // decimal to signed numeric
+ (Decimal128(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 |
Float64) |
+ (Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64) => true,
+ (Decimal128(_, _), _) => false,
+ (_, Decimal128(_, _)) => false,
+ (Decimal256(_, _), _) => false,
+ (_, Decimal256(_, _)) => false,
+ (Struct(_), _) => false,
+ (_, Struct(_)) => false,
(_, Boolean) => DataType::is_numeric(from_type) || from_type == &Utf8,
(Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8,
@@ -624,6 +626,103 @@ pub fn cast_with_options(
return Ok(array.clone());
}
match (from_type, to_type) {
+ (
+ Null,
+ Boolean
+ | Int8
+ | UInt8
+ | Int16
+ | UInt16
+ | Int32
+ | UInt32
+ | Float32
+ | Date32
+ | Time32(_)
+ | Int64
+ | UInt64
+ | Float64
+ | Date64
+ | Timestamp(_, _)
+ | Time64(_)
+ | Duration(_)
+ | Interval(_)
+ | FixedSizeBinary(_)
+ | Binary
+ | Utf8
+ | LargeBinary
+ | LargeUtf8
+ | List(_)
+ | LargeList(_)
+ | FixedSizeList(_, _)
+ | Struct(_)
+ | Map(_, _)
+ | Dictionary(_, _),
+ ) => Ok(new_null_array(to_type, array.len())),
+ (Dictionary(index_type, _), _) => match **index_type {
+ Int8 => dictionary_cast::<Int8Type>(array, to_type, cast_options),
+ Int16 => dictionary_cast::<Int16Type>(array, to_type,
cast_options),
+ Int32 => dictionary_cast::<Int32Type>(array, to_type,
cast_options),
+ Int64 => dictionary_cast::<Int64Type>(array, to_type,
cast_options),
+ UInt8 => dictionary_cast::<UInt8Type>(array, to_type,
cast_options),
+ UInt16 => dictionary_cast::<UInt16Type>(array, to_type,
cast_options),
+ UInt32 => dictionary_cast::<UInt32Type>(array, to_type,
cast_options),
+ UInt64 => dictionary_cast::<UInt64Type>(array, to_type,
cast_options),
+ _ => Err(ArrowError::CastError(format!(
+ "Casting from dictionary type {:?} to {:?} not supported",
+ from_type, to_type,
+ ))),
+ },
+ (_, Dictionary(index_type, value_type)) => match **index_type {
+ Int8 => cast_to_dictionary::<Int8Type>(array, value_type,
cast_options),
+ Int16 => cast_to_dictionary::<Int16Type>(array, value_type,
cast_options),
+ Int32 => cast_to_dictionary::<Int32Type>(array, value_type,
cast_options),
+ Int64 => cast_to_dictionary::<Int64Type>(array, value_type,
cast_options),
+ UInt8 => cast_to_dictionary::<UInt8Type>(array, value_type,
cast_options),
+ UInt16 => cast_to_dictionary::<UInt16Type>(array, value_type,
cast_options),
+ UInt32 => cast_to_dictionary::<UInt32Type>(array, value_type,
cast_options),
+ UInt64 => cast_to_dictionary::<UInt64Type>(array, value_type,
cast_options),
+ _ => Err(ArrowError::CastError(format!(
+ "Casting from type {:?} to dictionary type {:?} not supported",
+ from_type, to_type,
+ ))),
+ },
+ (List(_), List(ref to)) => {
+ cast_list_inner::<i32>(array, to, to_type, cast_options)
+ }
+ (LargeList(_), LargeList(ref to)) => {
+ cast_list_inner::<i64>(array, to, to_type, cast_options)
+ }
+ (List(list_from), LargeList(list_to)) => {
+ if list_to.data_type() != list_from.data_type() {
+ Err(ArrowError::CastError(
+ "cannot cast list to large-list with different child
data".into(),
+ ))
+ } else {
+ cast_list_container::<i32, i64>(&**array, cast_options)
+ }
+ }
+ (LargeList(list_from), List(list_to)) => {
+ if list_to.data_type() != list_from.data_type() {
+ Err(ArrowError::CastError(
+ "cannot cast large-list to list with different child
data".into(),
+ ))
+ } else {
+ cast_list_container::<i64, i32>(&**array, cast_options)
+ }
+ }
+ (List(_) | LargeList(_), _) => match to_type {
+ Utf8 => cast_list_to_string!(array, i32),
+ LargeUtf8 => cast_list_to_string!(array, i64),
+ _ => Err(ArrowError::CastError(
+ "Cannot cast list to non-list data types".to_string(),
+ )),
+ },
+ (_, List(ref to)) => {
+ cast_primitive_to_list::<i32>(array, to, to_type, cast_options)
+ }
+ (_, LargeList(ref to)) => {
+ cast_primitive_to_list::<i64>(array, to, to_type, cast_options)
+ }
(Decimal128(_, s1), Decimal128(p2, s2)) => {
cast_decimal_to_decimal_with_option::<16, 16>(array, s1, p2, s2,
cast_options)
}
@@ -887,107 +986,12 @@ pub fn cast_with_options(
))),
}
}
- (
- Null,
- Boolean
- | Int8
- | UInt8
- | Int16
- | UInt16
- | Int32
- | UInt32
- | Float32
- | Date32
- | Time32(_)
- | Int64
- | UInt64
- | Float64
- | Date64
- | Timestamp(_, _)
- | Time64(_)
- | Duration(_)
- | Interval(_)
- | FixedSizeBinary(_)
- | Binary
- | Utf8
- | LargeBinary
- | LargeUtf8
- | List(_)
- | LargeList(_)
- | FixedSizeList(_, _)
- | Struct(_)
- | Map(_, _)
- | Dictionary(_, _),
- ) => Ok(new_null_array(to_type, array.len())),
(Struct(_), _) => Err(ArrowError::CastError(
"Cannot cast from struct to other types".to_string(),
)),
(_, Struct(_)) => Err(ArrowError::CastError(
"Cannot cast to struct from other types".to_string(),
)),
- (List(_), List(ref to)) => {
- cast_list_inner::<i32>(array, to, to_type, cast_options)
- }
- (LargeList(_), LargeList(ref to)) => {
- cast_list_inner::<i64>(array, to, to_type, cast_options)
- }
- (List(list_from), LargeList(list_to)) => {
- if list_to.data_type() != list_from.data_type() {
- Err(ArrowError::CastError(
- "cannot cast list to large-list with different child
data".into(),
- ))
- } else {
- cast_list_container::<i32, i64>(&**array, cast_options)
- }
- }
- (LargeList(list_from), List(list_to)) => {
- if list_to.data_type() != list_from.data_type() {
- Err(ArrowError::CastError(
- "cannot cast large-list to list with different child
data".into(),
- ))
- } else {
- cast_list_container::<i64, i32>(&**array, cast_options)
- }
- }
- (List(_) | LargeList(_), Utf8) => cast_list_to_string!(array, i32),
- (List(_) | LargeList(_), LargeUtf8) => cast_list_to_string!(array,
i64),
- (List(_), _) => Err(ArrowError::CastError(
- "Cannot cast list to non-list data types".to_string(),
- )),
- (_, List(ref to)) => {
- cast_primitive_to_list::<i32>(array, to, to_type, cast_options)
- }
- (_, LargeList(ref to)) => {
- cast_primitive_to_list::<i64>(array, to, to_type, cast_options)
- }
- (Dictionary(index_type, _), _) => match **index_type {
- Int8 => dictionary_cast::<Int8Type>(array, to_type, cast_options),
- Int16 => dictionary_cast::<Int16Type>(array, to_type,
cast_options),
- Int32 => dictionary_cast::<Int32Type>(array, to_type,
cast_options),
- Int64 => dictionary_cast::<Int64Type>(array, to_type,
cast_options),
- UInt8 => dictionary_cast::<UInt8Type>(array, to_type,
cast_options),
- UInt16 => dictionary_cast::<UInt16Type>(array, to_type,
cast_options),
- UInt32 => dictionary_cast::<UInt32Type>(array, to_type,
cast_options),
- UInt64 => dictionary_cast::<UInt64Type>(array, to_type,
cast_options),
- _ => Err(ArrowError::CastError(format!(
- "Casting from dictionary type {:?} to {:?} not supported",
- from_type, to_type,
- ))),
- },
- (_, Dictionary(index_type, value_type)) => match **index_type {
- Int8 => cast_to_dictionary::<Int8Type>(array, value_type,
cast_options),
- Int16 => cast_to_dictionary::<Int16Type>(array, value_type,
cast_options),
- Int32 => cast_to_dictionary::<Int32Type>(array, value_type,
cast_options),
- Int64 => cast_to_dictionary::<Int64Type>(array, value_type,
cast_options),
- UInt8 => cast_to_dictionary::<UInt8Type>(array, value_type,
cast_options),
- UInt16 => cast_to_dictionary::<UInt16Type>(array, value_type,
cast_options),
- UInt32 => cast_to_dictionary::<UInt32Type>(array, value_type,
cast_options),
- UInt64 => cast_to_dictionary::<UInt64Type>(array, value_type,
cast_options),
- _ => Err(ArrowError::CastError(format!(
- "Casting from type {:?} to dictionary type {:?} not supported",
- from_type, to_type,
- ))),
- },
(_, Boolean) => match from_type {
UInt8 => cast_numeric_to_bool::<UInt8Type>(array),
UInt16 => cast_numeric_to_bool::<UInt16Type>(array),
@@ -3390,7 +3394,18 @@ fn cast_to_dictionary<K: ArrowDictionaryKeyType>(
dict_value_type,
cast_options,
),
+ Decimal128(_, _) => pack_numeric_to_dictionary::<K, Decimal128Type>(
+ array,
+ dict_value_type,
+ cast_options,
+ ),
+ Decimal256(_, _) => pack_numeric_to_dictionary::<K, Decimal256Type>(
+ array,
+ dict_value_type,
+ cast_options,
+ ),
Utf8 => pack_string_to_dictionary::<K>(array, cast_options),
+ LargeUtf8 => pack_string_to_dictionary::<K>(array, cast_options),
_ => Err(ArrowError::CastError(format!(
"Unsupported output type for dictionary packing: {:?}",
dict_value_type
diff --git a/arrow/tests/array_cast.rs b/arrow/tests/array_cast.rs
index 95fb97328..be37a7636 100644
--- a/arrow/tests/array_cast.rs
+++ b/arrow/tests/array_cast.rs
@@ -19,12 +19,13 @@ use arrow_array::builder::{
PrimitiveDictionaryBuilder, StringDictionaryBuilder, UnionBuilder,
};
use arrow_array::types::{
- ArrowDictionaryKeyType, Int16Type, Int32Type, Int64Type, Int8Type,
- TimestampMicrosecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+ ArrowDictionaryKeyType, Decimal128Type, Decimal256Type, Int16Type,
Int32Type,
+ Int64Type, Int8Type, TimestampMicrosecondType, UInt16Type, UInt32Type,
UInt64Type,
+ UInt8Type,
};
use arrow_array::{
- Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array,
- Decimal128Array, DurationMicrosecondArray, DurationMillisecondArray,
+ Array, ArrayRef, ArrowPrimitiveType, BinaryArray, BooleanArray,
Date32Array,
+ Date64Array, Decimal128Array, DurationMicrosecondArray,
DurationMillisecondArray,
DurationNanosecondArray, DurationSecondArray, FixedSizeBinaryArray,
FixedSizeListArray, Float16Array, Float32Array, Float64Array, Int16Array,
Int32Array,
Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray,
@@ -35,7 +36,7 @@ use arrow_array::{
TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
UInt64Array, UInt8Array, UnionArray,
};
-use arrow_buffer::Buffer;
+use arrow_buffer::{i256, Buffer};
use arrow_cast::{can_cast_types, cast};
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, Field, IntervalUnit, TimeUnit,
UnionMode};
@@ -101,14 +102,14 @@ fn get_arrays_of_all_types() -> Vec<ArrayRef> {
vec![
Arc::new(BinaryArray::from(binary_data.clone())),
Arc::new(LargeBinaryArray::from(binary_data.clone())),
- make_dictionary_primitive::<Int8Type>(),
- make_dictionary_primitive::<Int16Type>(),
- make_dictionary_primitive::<Int32Type>(),
- make_dictionary_primitive::<Int64Type>(),
- make_dictionary_primitive::<UInt8Type>(),
- make_dictionary_primitive::<UInt16Type>(),
- make_dictionary_primitive::<UInt32Type>(),
- make_dictionary_primitive::<UInt64Type>(),
+ make_dictionary_primitive::<Int8Type, Int32Type>(vec![1, 2]),
+ make_dictionary_primitive::<Int16Type, Int32Type>(vec![1, 2]),
+ make_dictionary_primitive::<Int32Type, Int32Type>(vec![1, 2]),
+ make_dictionary_primitive::<Int64Type, Int32Type>(vec![1, 2]),
+ make_dictionary_primitive::<UInt8Type, Int32Type>(vec![1, 2]),
+ make_dictionary_primitive::<UInt16Type, Int32Type>(vec![1, 2]),
+ make_dictionary_primitive::<UInt32Type, Int32Type>(vec![1, 2]),
+ make_dictionary_primitive::<UInt64Type, Int32Type>(vec![1, 2]),
make_dictionary_utf8::<Int8Type>(),
make_dictionary_utf8::<Int16Type>(),
make_dictionary_utf8::<Int32Type>(),
@@ -184,6 +185,46 @@ fn get_arrays_of_all_types() -> Vec<ArrayRef> {
Arc::new(
create_decimal_array(vec![Some(1), Some(2), Some(3), None], 38,
0).unwrap(),
),
+ make_dictionary_primitive::<Int8Type, Decimal128Type>(vec![1, 2]),
+ make_dictionary_primitive::<Int16Type, Decimal128Type>(vec![1, 2]),
+ make_dictionary_primitive::<Int32Type, Decimal128Type>(vec![1, 2]),
+ make_dictionary_primitive::<Int64Type, Decimal128Type>(vec![1, 2]),
+ make_dictionary_primitive::<UInt8Type, Decimal128Type>(vec![1, 2]),
+ make_dictionary_primitive::<UInt16Type, Decimal128Type>(vec![1, 2]),
+ make_dictionary_primitive::<UInt32Type, Decimal128Type>(vec![1, 2]),
+ make_dictionary_primitive::<UInt64Type, Decimal128Type>(vec![1, 2]),
+ make_dictionary_primitive::<Int8Type, Decimal256Type>(vec![
+ i256::from_i128(1),
+ i256::from_i128(2),
+ ]),
+ make_dictionary_primitive::<Int16Type, Decimal256Type>(vec![
+ i256::from_i128(1),
+ i256::from_i128(2),
+ ]),
+ make_dictionary_primitive::<Int32Type, Decimal256Type>(vec![
+ i256::from_i128(1),
+ i256::from_i128(2),
+ ]),
+ make_dictionary_primitive::<Int64Type, Decimal256Type>(vec![
+ i256::from_i128(1),
+ i256::from_i128(2),
+ ]),
+ make_dictionary_primitive::<UInt8Type, Decimal256Type>(vec![
+ i256::from_i128(1),
+ i256::from_i128(2),
+ ]),
+ make_dictionary_primitive::<UInt16Type, Decimal256Type>(vec![
+ i256::from_i128(1),
+ i256::from_i128(2),
+ ]),
+ make_dictionary_primitive::<UInt32Type, Decimal256Type>(vec![
+ i256::from_i128(1),
+ i256::from_i128(2),
+ ]),
+ make_dictionary_primitive::<UInt64Type, Decimal256Type>(vec![
+ i256::from_i128(1),
+ i256::from_i128(2),
+ ]),
]
}
@@ -273,12 +314,15 @@ fn make_union_array() -> UnionArray {
}
/// Creates a dictionary with primitive dictionary values, and keys of type K
-fn make_dictionary_primitive<K: ArrowDictionaryKeyType>() -> ArrayRef {
+/// and values of type V
+fn make_dictionary_primitive<K: ArrowDictionaryKeyType, V: ArrowPrimitiveType>(
+ values: Vec<V::Native>,
+) -> ArrayRef {
// Pick Int32 arbitrarily for dictionary values
- let mut b: PrimitiveDictionaryBuilder<K, Int32Type> =
- PrimitiveDictionaryBuilder::new();
- b.append(1).unwrap();
- b.append(2).unwrap();
+ let mut b: PrimitiveDictionaryBuilder<K, V> =
PrimitiveDictionaryBuilder::new();
+ values.iter().for_each(|v| {
+ b.append(*v).unwrap();
+ });
Arc::new(b.finish())
}
@@ -369,6 +413,12 @@ fn get_all_types() -> Vec<DataType> {
Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)),
Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
Decimal128(38, 0),
+ Dictionary(Box::new(DataType::Int8), Box::new(Decimal128(38, 0))),
+ Dictionary(Box::new(DataType::Int16), Box::new(Decimal128(38, 0))),
+ Dictionary(Box::new(DataType::UInt32), Box::new(Decimal128(38, 0))),
+ Dictionary(Box::new(DataType::Int8), Box::new(Decimal256(76, 0))),
+ Dictionary(Box::new(DataType::Int16), Box::new(Decimal256(76, 0))),
+ Dictionary(Box::new(DataType::UInt32), Box::new(Decimal256(76, 0))),
]
}