This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 36d31eb ARROW-11481: [Rust] More cast implementations
36d31eb is described below
commit 36d31eb5b0758a253ef231fdbae3275b9b7360ef
Author: Ritchie Vink <[email protected]>
AuthorDate: Tue Feb 16 14:57:54 2021 -0500
ARROW-11481: [Rust] More cast implementations
This PR adds several cast implementations that were previously missing.
Included casts:
```
* LargeList -> List (same child datatype only; returns an error otherwise)
* List -> LargeList (same child datatype only; returns an error otherwise)
* LargeList<A> -> LargeList<B>
* Int32 -> Date64
* Date32 -> Int64
* Int64 -> Date32 (lossy)
* Date64 -> Int32 (lossy)
```
Closes #9402 from ritchie46/more_casts
Authored-by: Ritchie Vink <[email protected]>
Signed-off-by: Andrew Lamb <[email protected]>
---
rust/arrow/src/array/data.rs | 2 +-
rust/arrow/src/compute/kernels/cast.rs | 294 ++++++++++++++++++++++++++++-----
2 files changed, 252 insertions(+), 44 deletions(-)
diff --git a/rust/arrow/src/array/data.rs b/rust/arrow/src/array/data.rs
index c118515..0a10e9f 100644
--- a/rust/arrow/src/array/data.rs
+++ b/rust/arrow/src/array/data.rs
@@ -402,7 +402,7 @@ impl ArrayData {
/// * the buffer is not byte-aligned with type T, or
/// * the datatype is `Boolean` (it corresponds to a bit-packed buffer
where the offset is not applicable)
#[inline]
- pub(super) fn buffer<T: ArrowNativeType>(&self, buffer: usize) -> &[T] {
+ pub(crate) fn buffer<T: ArrowNativeType>(&self, buffer: usize) -> &[T] {
let values = unsafe { self.buffers[buffer].as_slice().align_to::<T>()
};
if !values.0.is_empty() || !values.2.is_empty() {
panic!("The buffer is not byte-aligned with its interpretation")
diff --git a/rust/arrow/src/compute/kernels/cast.rs b/rust/arrow/src/compute/kernels/cast.rs
index d487479..9a60757 100644
--- a/rust/arrow/src/compute/kernels/cast.rs
+++ b/rust/arrow/src/compute/kernels/cast.rs
@@ -38,6 +38,7 @@
use std::str;
use std::sync::Arc;
+use crate::buffer::MutableBuffer;
use crate::compute::kernels::arithmetic::{divide, multiply};
use crate::compute::kernels::arity::unary;
use crate::compute::kernels::cast_utils::string_to_timestamp_nanos;
@@ -45,6 +46,7 @@ use crate::datatypes::*;
use crate::error::{ArrowError, Result};
use crate::{array::*, compute::take};
use crate::{buffer::Buffer, util::serialization::lexical_to_string};
+use num::{NumCast, ToPrimitive};
/// Return true if a value of type `from_type` can be cast into a
/// value of `to_type`. Note that such as cast may be lossy.
@@ -59,11 +61,18 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
match (from_type, to_type) {
(Struct(_), _) => false,
(_, Struct(_)) => false,
+ (LargeList(list_from), LargeList(list_to)) => {
+ can_cast_types(list_from.data_type(), list_to.data_type())
+ }
(List(list_from), List(list_to)) => {
can_cast_types(list_from.data_type(), list_to.data_type())
}
+ (List(list_from), LargeList(list_to)) => {
+ list_from.data_type() == list_to.data_type()
+ }
(List(_), _) => false,
(_, List(list_to)) => can_cast_types(from_type, list_to.data_type()),
+ (_, LargeList(list_to)) => can_cast_types(from_type,
list_to.data_type()),
(Dictionary(_, from_value_type), Dictionary(_, to_value_type)) => {
can_cast_types(from_value_type, to_value_type)
}
@@ -183,12 +192,16 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
// temporal casts
(Int32, Date32) => true,
+ (Int32, Date64) => true,
(Int32, Time32(_)) => true,
(Date32, Int32) => true,
+ (Date32, Int64) => true,
(Time32(_), Int32) => true,
(Int64, Date64) => true,
+ (Int64, Date32) => true,
(Int64, Time64(_)) => true,
(Date64, Int64) => true,
+ (Date64, Int32) => true,
(Time64(_), Int64) => true,
(Date32, Date64) => true,
(Date64, Date32) => true,
@@ -247,53 +260,31 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
(_, Struct(_)) => Err(ArrowError::ComputeError(
"Cannot cast to struct from other types".to_string(),
)),
- (List(_), List(ref to)) => {
- let data = array.data_ref();
- let underlying_array = make_array(data.child_data()[0].clone());
- let cast_array = cast(&underlying_array, to.data_type())?;
- let array_data = ArrayData::new(
- to.data_type().clone(),
- array.len(),
- Some(cast_array.null_count()),
- cast_array
- .data()
- .null_bitmap()
- .clone()
- .map(|bitmap| bitmap.bits),
- array.offset(),
- // reuse offset buffer
- data.buffers().to_vec(),
- vec![cast_array.data()],
- );
- let list = ListArray::from(Arc::new(array_data));
- Ok(Arc::new(list) as ArrayRef)
+ (List(_), List(ref to)) => cast_list_inner::<i32>(&**array, to),
+ (LargeList(_), LargeList(ref to)) => cast_list_inner::<i64>(&**array,
to),
+ (List(list_from), LargeList(list_to)) => {
+ if list_to.data_type() != list_from.data_type() {
+ Err(ArrowError::ComputeError(
+ "cannot cast list to large-list with different child
data".into(),
+ ))
+ } else {
+ cast_list_container::<i32, i64>(&**array)
+ }
+ }
+ (LargeList(list_from), List(list_to)) => {
+ if list_to.data_type() != list_from.data_type() {
+ Err(ArrowError::ComputeError(
+ "cannot cast large-list to list with different child
data".into(),
+ ))
+ } else {
+ cast_list_container::<i64, i32>(&**array)
+ }
}
(List(_), _) => Err(ArrowError::ComputeError(
"Cannot cast list to non-list data types".to_string(),
)),
- (_, List(ref to)) => {
- // cast primitive to list's primitive
- let cast_array = cast(array, to.data_type())?;
- // create offsets, where if array.len() = 2, we have [0,1,2]
- let offsets: Vec<i32> = (0..=array.len() as i32).collect();
- let value_offsets = Buffer::from_slice_ref(&offsets);
- let list_data = ArrayData::new(
- to.data_type().clone(),
- array.len(),
- Some(cast_array.null_count()),
- cast_array
- .data()
- .null_bitmap()
- .clone()
- .map(|bitmap| bitmap.bits),
- 0,
- vec![value_offsets],
- vec![cast_array.data()],
- );
- let list_array = Arc::new(ListArray::from(Arc::new(list_data))) as
ArrayRef;
-
- Ok(list_array)
- }
+ (_, List(ref to)) => cast_primitive_to_list::<i32>(array, to),
+ (_, LargeList(ref to)) => cast_primitive_to_list::<i64>(array, to),
(Dictionary(index_type, _), _) => match **index_type {
DataType::Int8 => dictionary_cast::<Int8Type>(array, to_type),
DataType::Int16 => dictionary_cast::<Int16Type>(array, to_type),
@@ -566,6 +557,7 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
// temporal casts
(Int32, Date32) => cast_array_data::<Date32Type>(array,
to_type.clone()),
+ (Int32, Date64) => cast(&cast(array, &DataType::Date32)?,
&DataType::Date64),
(Int32, Time32(TimeUnit::Second)) => {
cast_array_data::<Time32SecondType>(array, to_type.clone())
}
@@ -574,8 +566,10 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
}
// No support for microsecond/nanosecond with i32
(Date32, Int32) => cast_array_data::<Int32Type>(array,
to_type.clone()),
+ (Date32, Int64) => cast(&cast(array, &DataType::Int32)?,
&DataType::Int64),
(Time32(_), Int32) => cast_array_data::<Int32Type>(array,
to_type.clone()),
(Int64, Date64) => cast_array_data::<Date64Type>(array,
to_type.clone()),
+ (Int64, Date32) => cast(&cast(array, &DataType::Int32)?,
&DataType::Date32),
// No support for second/milliseconds with i64
(Int64, Time64(TimeUnit::Microsecond)) => {
cast_array_data::<Time64MicrosecondType>(array, to_type.clone())
@@ -585,6 +579,7 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
}
(Date64, Int64) => cast_array_data::<Int64Type>(array,
to_type.clone()),
+ (Date64, Int32) => cast(&cast(array, &DataType::Int64)?,
&DataType::Int32),
(Time64(_), Int64) => cast_array_data::<Int64Type>(array,
to_type.clone()),
(Date32, Date64) => {
let date_array =
array.as_any().downcast_ref::<Date32Array>().unwrap();
@@ -1207,6 +1202,137 @@ where
Ok(Arc::new(b.finish()))
}
+/// Helper function that takes a primitive array and casts it to a (generic) list
+/// array in which every element becomes a single-item list.
+///
+/// The values are first cast to `to.data_type()`. The resulting container has data
+/// type `List(to)` when `OffsetSize` is 4 bytes wide (`i32`) and `LargeList(to)`
+/// when it is 8 bytes wide (`i64`).
+fn cast_primitive_to_list<OffsetSize: OffsetSizeTrait + NumCast>(
+    array: &ArrayRef,
+    to: &Field,
+) -> Result<ArrayRef> {
+    // cast primitive to list's primitive
+    let cast_array = cast(array, to.data_type())?;
+    // create offsets, where if array.len() = 2, we have [0,1,2]
+    // Safety:
+    // Length of range can be trusted.
+    // Note: could not yet create a generic range in stable Rust.
+    let offsets = unsafe {
+        MutableBuffer::from_trusted_len_iter(
+            (0..=array.len()).map(|i| OffsetSize::from(i).expect("integer")),
+        )
+    };
+
+    // The container must carry the *list* type, not the type of its values;
+    // the offset width decides between `List` and `LargeList`.
+    let list_type = if std::mem::size_of::<OffsetSize>() == std::mem::size_of::<i64>() {
+        DataType::LargeList(Box::new(to.clone()))
+    } else {
+        DataType::List(Box::new(to.clone()))
+    };
+
+    let list_data = ArrayData::new(
+        list_type,
+        array.len(),
+        Some(cast_array.null_count()),
+        cast_array
+            .data()
+            .null_bitmap()
+            .clone()
+            .map(|bitmap| bitmap.bits),
+        0,
+        vec![offsets.into()],
+        vec![cast_array.data()],
+    );
+    let list_array =
+        Arc::new(GenericListArray::<OffsetSize>::from(Arc::new(list_data))) as ArrayRef;
+
+    Ok(list_array)
+}
+
+/// Helper function that takes a generic list container and casts its inner (child)
+/// datatype to `to.data_type()`, reusing the container's offsets and validity.
+///
+/// The result is a `List(to)` when `OffsetSize` is 4 bytes wide (`i32`) and a
+/// `LargeList(to)` when it is 8 bytes wide (`i64`).
+fn cast_list_inner<OffsetSize: OffsetSizeTrait>(
+    array: &dyn Array,
+    to: &Field,
+) -> Result<ArrayRef> {
+    let data = array.data_ref();
+    let underlying_array = make_array(data.child_data()[0].clone());
+    let cast_array = cast(&underlying_array, to.data_type())?;
+    // The container keeps its own kind; only the field describing its values changes.
+    let list_type = if std::mem::size_of::<OffsetSize>() == std::mem::size_of::<i64>() {
+        DataType::LargeList(Box::new(to.clone()))
+    } else {
+        DataType::List(Box::new(to.clone()))
+    };
+    let array_data = ArrayData::new(
+        list_type,
+        array.len(),
+        // Validity belongs to the list slots, not to the (differently sized) child
+        // values, so it is taken from the container itself.
+        Some(data.null_count()),
+        data.null_bitmap().clone().map(|bitmap| bitmap.bits),
+        array.offset(),
+        // reuse offset buffer
+        data.buffers().to_vec(),
+        vec![cast_array.data()],
+    );
+    let list = GenericListArray::<OffsetSize>::from(Arc::new(array_data));
+    Ok(Arc::new(list) as ArrayRef)
+}
+
+/// Cast the container type of a List/LargeList array but not the inner types.
+/// The value data is reused untouched; only the offsets are converted to the new
+/// offset width.
+///
+/// # Errors
+/// Returns a `ComputeError` when casting a `LargeList` whose value data is longer
+/// than `i32::MAX`, since such offsets cannot be represented in a `List`.
+fn cast_list_container<OffsetSizeFrom, OffsetSizeTo>(
+    array: &dyn Array,
+) -> Result<ArrayRef>
+where
+    OffsetSizeFrom: OffsetSizeTrait + ToPrimitive,
+    OffsetSizeTo: OffsetSizeTrait + NumCast,
+{
+    let data = array.data_ref();
+    // the value data stored by the list
+    let value_data = data.child_data()[0].clone();
+
+    let out_dtype = match array.data_type() {
+        DataType::List(value_type) => {
+            assert_eq!(
+                std::mem::size_of::<OffsetSizeFrom>(),
+                std::mem::size_of::<i32>()
+            );
+            assert_eq!(
+                std::mem::size_of::<OffsetSizeTo>(),
+                std::mem::size_of::<i64>()
+            );
+            DataType::LargeList(value_type.clone())
+        }
+        DataType::LargeList(value_type) => {
+            assert_eq!(
+                std::mem::size_of::<OffsetSizeFrom>(),
+                std::mem::size_of::<i64>()
+            );
+            assert_eq!(
+                std::mem::size_of::<OffsetSizeTo>(),
+                std::mem::size_of::<i32>()
+            );
+            if value_data.len() > i32::MAX as usize {
+                return Err(ArrowError::ComputeError(
+                    "LargeList too large to cast to List".into(),
+                ));
+            }
+            DataType::List(value_type.clone())
+        }
+        // implementation error
+        _ => unreachable!(),
+    };
+
+    // Convert the *whole* offsets buffer instead of `data.buffer(0)`, which skips the
+    // first `array.offset()` entries. The array offset is carried over to the new
+    // array below, so the (unsliced) validity bitmap stays aligned with the offsets
+    // for sliced inputs, and the physical offsets buffer keeps its original first value.
+    // SAFETY: the offsets buffer holds plain integers, for which every bit pattern is
+    // valid; a misaligned buffer is rejected by the prefix/suffix check below.
+    let (prefix, offsets, suffix) =
+        unsafe { data.buffers()[0].as_slice().align_to::<OffsetSizeFrom>() };
+    assert!(
+        prefix.is_empty() && suffix.is_empty(),
+        "The buffer is not byte-aligned with its interpretation"
+    );
+
+    let iter = offsets.iter().map(|idx| {
+        let idx: OffsetSizeTo = NumCast::from(*idx)
+            .expect("list offset does not fit in the target offset type");
+        idx
+    });
+
+    // SAFETY
+    // A slice produces a trusted length iterator
+    let offset_buffer = unsafe { Buffer::from_trusted_len_iter(iter) };
+
+    // wrap up
+    let mut builder = ArrayData::builder(out_dtype)
+        .len(array.len())
+        .offset(array.offset())
+        .add_buffer(offset_buffer)
+        .add_child_data(value_data);
+
+    if let Some(buf) = data.null_buffer() {
+        builder = builder.null_bit_buffer(buf.clone())
+    }
+    let data = builder.build();
+    Ok(make_array(data))
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -2245,6 +2371,11 @@ mod tests {
get_cast_values::<Int32Type>(&i64_array, &DataType::Int32)
);
+ assert_eq!(
+ i32_expected,
+ get_cast_values::<Date32Type>(&i64_array, &DataType::Date32)
+ );
+
let i16_expected = vec![
"null", "null", "-32768", "-128", "0", "127", "32767", "null",
"null",
];
@@ -2388,6 +2519,21 @@ mod tests {
u8_expected,
get_cast_values::<UInt8Type>(&i32_array, &DataType::UInt8)
);
+
+ // The date32 to date64 cast increases the numerical values in order
to keep the same dates.
+ let i64_expected = vec![
+ "-185542587187200000",
+ "-2831155200000",
+ "-11059200000",
+ "0",
+ "10972800000",
+ "2831068800000",
+ "185542587100800000",
+ ];
+ assert_eq!(
+ i64_expected,
+ get_cast_values::<Date64Type>(&i32_array, &DataType::Date64)
+ );
}
#[test]
@@ -2463,6 +2609,34 @@ mod tests {
}
#[test]
+    fn test_cast_from_date32() {
+        // Date32 day counts covering the extremes of the i32, i16 and i8 ranges.
+        let days = [
+            std::i32::MIN,
+            i32::from(std::i16::MIN),
+            i32::from(std::i8::MIN),
+            0,
+            i32::from(std::i8::MAX),
+            i32::from(std::i16::MAX),
+            std::i32::MAX,
+        ];
+        let date32_array: ArrayRef = Arc::new(Date32Array::from(days.to_vec()));
+
+        // Widening to Int64 is lossless, so each day count must come back verbatim.
+        let expected: Vec<String> = days.iter().map(|d| d.to_string()).collect();
+        let actual = get_cast_values::<Int64Type>(&date32_array, &DataType::Int64);
+        assert_eq!(expected, actual);
+    }
+
+ #[test]
fn test_cast_from_int8() {
let i8_values: Vec<i8> = vec![std::i8::MIN, 0, std::i8::MAX];
let i8_array: ArrayRef = Arc::new(Int8Array::from(i8_values));
@@ -2857,6 +3031,40 @@ mod tests {
}
}
+    #[test]
+    fn test_cast_list_containers() {
+        // LargeList -> List: every element must survive the offset narrowing.
+        let large: ArrayRef = Arc::new(make_large_list_array());
+        let as_list = cast(
+            &large,
+            &DataType::List(Box::new(Field::new("", DataType::Int32, false))),
+        )
+        .unwrap();
+        let target = as_list.as_any().downcast_ref::<ListArray>().unwrap();
+        let source = large.as_any().downcast_ref::<LargeListArray>().unwrap();
+        for i in 0..3 {
+            assert_eq!(&source.value(i), &target.value(i));
+        }
+
+        // List -> LargeList: every element must survive the offset widening.
+        let small: ArrayRef = Arc::new(make_list_array());
+        let as_large = cast(
+            &small,
+            &DataType::LargeList(Box::new(Field::new("", DataType::Int32, false))),
+        )
+        .unwrap();
+        let target = as_large.as_any().downcast_ref::<LargeListArray>().unwrap();
+        let source = small.as_any().downcast_ref::<ListArray>().unwrap();
+        for i in 0..3 {
+            assert_eq!(&source.value(i), &target.value(i));
+        }
+    }
+
/// Create instances of arrays with varying types for cast tests
fn get_arrays_of_all_types() -> Vec<ArrayRef> {
let tz_name = String::from("America/New_York");