This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 810291179f Take kernel dyn Array (#4705)
810291179f is described below
commit 810291179f65d63a5c49ed6b7881bc5788d85a9e
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Thu Aug 17 10:48:33 2023 +0100
Take kernel dyn Array (#4705)
---
arrow-cast/src/cast.rs | 16 +----
arrow-select/src/take.rs | 153 +++++++++++++++++++++++++++++++++++------------
2 files changed, 116 insertions(+), 53 deletions(-)
diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index a08a7a4fd4..23b7a4b5a0 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -49,7 +49,7 @@ use crate::parse::{
use arrow_array::{
builder::*, cast::*, temporal_conversions::*, timezone::Tz, types::*, *,
};
-use arrow_buffer::{i256, ArrowNativeType, Buffer, OffsetBuffer, ScalarBuffer};
+use arrow_buffer::{i256, ArrowNativeType, Buffer, OffsetBuffer};
use arrow_data::ArrayData;
use arrow_schema::*;
use arrow_select::take::take;
@@ -3027,19 +3027,7 @@ where
{
let dict_array = array.as_dictionary::<K>();
let cast_dict_values = cast_with_options(dict_array.values(), to_type,
cast_options)?;
- let keys = dict_array.keys();
- match K::DATA_TYPE {
- DataType::Int32 => {
- // Dictionary guarantees all non-null keys >= 0
- let buffer = ScalarBuffer::new(keys.values().inner().clone(), 0,
keys.len());
- let indices = PrimitiveArray::new(buffer, keys.nulls().cloned());
- take::<UInt32Type>(cast_dict_values.as_ref(), &indices, None)
- }
- _ => {
- let indices = cast_with_options(keys, &DataType::UInt32,
cast_options)?;
- take::<UInt32Type>(cast_dict_values.as_ref(),
indices.as_primitive(), None)
- }
- }
+ take(cast_dict_values.as_ref(), dict_array.keys(), None)
}
/// Attempts to encode an array into an `ArrayDictionary` with index
diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs
index cee9cbaf84..70b80e5878 100644
--- a/arrow-select/src/take.rs
+++ b/arrow-select/src/take.rs
@@ -73,49 +73,65 @@ use num::{One, Zero};
///
/// assert_eq!(*taken, StringArray::from(vec!["two", "one"]));
/// ```
-pub fn take<IndexType: ArrowPrimitiveType>(
+pub fn take(
values: &dyn Array,
- indices: &PrimitiveArray<IndexType>,
+ indices: &dyn Array,
options: Option<TakeOptions>,
) -> Result<ArrayRef, ArrowError> {
- take_impl(values, indices, options)
+ let options = options.unwrap_or_default();
+ macro_rules! helper {
+ ($t:ty, $values:expr, $indices:expr, $options:expr) => {{
+ let indices = indices.as_primitive::<$t>();
+ if $options.check_bounds {
+ check_bounds($values.len(), indices)?;
+ }
+ let indices = indices.to_indices();
+ take_impl($values, &indices)
+ }};
+ }
+ downcast_integer! {
+ indices.data_type() => (helper, values, indices, options),
+ d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
+ }
+}
+
+/// Verifies that the non-null values of `indices` are all `< len`
+fn check_bounds<T: ArrowPrimitiveType>(
+ len: usize,
+ indices: &PrimitiveArray<T>,
+) -> Result<(), ArrowError> {
+ if indices.null_count() > 0 {
+ indices.iter().flatten().try_for_each(|index| {
+ let ix = index.to_usize().ok_or_else(|| {
+ ArrowError::ComputeError("Cast to usize failed".to_string())
+ })?;
+ if ix >= len {
+ return Err(ArrowError::ComputeError(
+ format!("Array index out of bounds, cannot get item at index {ix} from {len} entries"))
+ );
+ }
+ Ok(())
+ })
+ } else {
+ indices.values().iter().try_for_each(|index| {
+ let ix = index.to_usize().ok_or_else(|| {
+ ArrowError::ComputeError("Cast to usize failed".to_string())
+ })?;
+ if ix >= len {
+ return Err(ArrowError::ComputeError(
+ format!("Array index out of bounds, cannot get item at index {ix} from {len} entries"))
+ );
+ }
+ Ok(())
+ })
+ }
}
+#[inline(never)]
fn take_impl<IndexType: ArrowPrimitiveType>(
values: &dyn Array,
indices: &PrimitiveArray<IndexType>,
- options: Option<TakeOptions>,
) -> Result<ArrayRef, ArrowError> {
- let options = options.unwrap_or_default();
- if options.check_bounds {
- let len = values.len();
- if indices.null_count() > 0 {
- indices.iter().flatten().try_for_each(|index| {
- let ix = index.to_usize().ok_or_else(|| {
- ArrowError::ComputeError("Cast to usize failed".to_string())
- })?;
- if ix >= len {
- return Err(ArrowError::ComputeError(
- format!("Array index out of bounds, cannot get item at index {ix} from {len} entries"))
- );
- }
- Ok(())
- })?;
- } else {
- indices.values().iter().try_for_each(|index| {
- let ix = index.to_usize().ok_or_else(|| {
- ArrowError::ComputeError("Cast to usize failed".to_string())
- })?;
- if ix >= len {
- return Err(ArrowError::ComputeError(
- format!("Array index out of bounds, cannot get item at index {ix} from {len} entries"))
- );
- }
- Ok(())
- })?
- }
- }
-
downcast_primitive_array! {
values => Ok(Arc::new(take_primitive(values, indices)?)),
DataType::Boolean => {
@@ -156,7 +172,7 @@ fn take_impl<IndexType: ArrowPrimitiveType>(
let arrays = array
.columns()
.iter()
- .map(|a| take_impl(a.as_ref(), indices, Some(options.clone())))
+ .map(|a| take_impl(a.as_ref(), indices))
.collect::<Result<Vec<ArrayRef>, _>>()?;
let fields: Vec<(FieldRef, ArrayRef)> =
fields.iter().cloned().zip(arrays).collect();
@@ -423,7 +439,7 @@ where
let (list_indices, offsets, null_buf) =
take_value_indices_from_list::<IndexType, OffsetType>(values,
indices)?;
- let taken = take_impl::<OffsetType>(values.values().as_ref(),
&list_indices, None)?;
+ let taken = take_impl::<OffsetType>(values.values().as_ref(),
&list_indices)?;
let value_offsets = Buffer::from_vec(offsets);
// create a new list with taken data and computed null information
let list_data = ArrayDataBuilder::new(values.data_type().clone())
@@ -449,7 +465,7 @@ fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
length: <UInt32Type as ArrowPrimitiveType>::Native,
) -> Result<FixedSizeListArray, ArrowError> {
let list_indices = take_value_indices_from_fixed_size_list(values,
indices, length)?;
- let taken = take_impl::<UInt32Type>(values.values().as_ref(),
&list_indices, None)?;
+ let taken = take_impl::<UInt32Type>(values.values().as_ref(),
&list_indices)?;
// determine null count and null buffer, which are a function of `values` and `indices`
let num_bytes = bit_util::ceil(indices.len(), 8);
@@ -676,6 +692,65 @@ where
Ok(PrimitiveArray::<UInt32Type>::from(values))
}
+/// To avoid generating take implementations for every index type, instead we
+/// only generate for UInt32 and UInt64 and coerce inputs to these types
+trait ToIndices {
+ type T: ArrowPrimitiveType;
+
+ fn to_indices(&self) -> PrimitiveArray<Self::T>;
+}
+
+macro_rules! to_indices_reinterpret {
+ ($t:ty, $o:ty) => {
+ impl ToIndices for PrimitiveArray<$t> {
+ type T = $o;
+
+ fn to_indices(&self) -> PrimitiveArray<$o> {
+ let cast =
+ ScalarBuffer::new(self.values().inner().clone(), 0,
self.len());
+ PrimitiveArray::new(cast, self.nulls().cloned())
+ }
+ }
+ };
+}
+
+macro_rules! to_indices_identity {
+ ($t:ty) => {
+ impl ToIndices for PrimitiveArray<$t> {
+ type T = $t;
+
+ fn to_indices(&self) -> PrimitiveArray<$t> {
+ self.clone()
+ }
+ }
+ };
+}
+
+macro_rules! to_indices_widening {
+ ($t:ty, $o:ty) => {
+ impl ToIndices for PrimitiveArray<$t> {
+ type T = UInt32Type;
+
+ fn to_indices(&self) -> PrimitiveArray<$o> {
+ let cast = self.values().iter().copied().map(|x| x as
_).collect();
+ PrimitiveArray::new(cast, self.nulls().cloned())
+ }
+ }
+ };
+}
+
+to_indices_widening!(UInt8Type, UInt32Type);
+to_indices_widening!(Int8Type, UInt32Type);
+
+to_indices_widening!(UInt16Type, UInt32Type);
+to_indices_widening!(Int16Type, UInt32Type);
+
+to_indices_identity!(UInt32Type);
+to_indices_reinterpret!(Int32Type, UInt32Type);
+
+to_indices_identity!(UInt64Type);
+to_indices_reinterpret!(Int64Type, UInt64Type);
+
#[cfg(test)]
mod tests {
use super::*;
@@ -767,7 +842,7 @@ mod tests {
{
let output = PrimitiveArray::<T>::from(data);
let expected = PrimitiveArray::<T>::from(expected_data);
- let output = take_impl(&output, index, options).unwrap();
+ let output = take(&output, index, options).unwrap();
let output =
output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
assert_eq!(output, &expected)
}
@@ -1078,7 +1153,7 @@ mod tests {
1_639_715_368_000_000_000,
])
.with_timezone("UTC".to_string());
- let result = take_impl(&input, &index, None).unwrap();
+ let result = take(&input, &index, None).unwrap();
match result.data_type() {
DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
assert_eq!(tz.clone(), Some("UTC".into()))