This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 1d874fa ARROW-11005: [Rust] Remove indirection from `take` kernel
1d874fa is described below
commit 1d874fa2024e2f613b6c30f5cbe92ca97de02b1a
Author: Jorge C. Leitao <[email protected]>
AuthorDate: Tue Dec 22 07:29:07 2020 -0500
ARROW-11005: [Rust] Remove indirection from `take` kernel
This PR is a small cleanup of the `take` kernel.
1. Replaces `&Arc<dyn Array>` with `&dyn Array`. This is a small API change,
but it makes it more obvious that the only requirement on the input to `take`
is that it implements the `Array` trait.
2. Internal concrete implementations of `take` are no longer `dyn`; instead,
they accept and return the corresponding concrete array types.
3. Clarified in the documentation that skipping bounds checks can lead
to undefined behavior.
4. Added the missing equality (`PartialEq`) implementation for `FixedSizeListArray`.
The rationale for this PR is that it makes the design of the module more
obvious: each array type has its own implementation (e.g.
`take_primitive<T>(values: &PrimitiveArray<T>, indices: ...) ->
PrimitiveArray<T>`), and there is a single `dyn` implementation, `take(values:
&dyn Array, indices: ...) -> ArrayRef`, that dispatches to each type-specific
implementation and wraps the result in an `Arc<dyn Array>` for the `dyn` behavior.
This is mostly for simplicity and readability, but these type-specific functions
could also be exposed publicly.
Closes #8985 from jorgecarleitao/simpler_take
Authored-by: Jorge C. Leitao <[email protected]>
Signed-off-by: Andrew Lamb <[email protected]>
---
rust/arrow/benches/take_kernels.rs | 26 +-
rust/arrow/src/array/equal/mod.rs | 11 +-
rust/arrow/src/compute/kernels/cast.rs | 2 +-
rust/arrow/src/compute/kernels/sort.rs | 4 +-
rust/arrow/src/compute/kernels/take.rs | 302 +++++++++++----------
.../datafusion/src/physical_plan/hash_aggregate.rs | 2 +-
rust/datafusion/src/physical_plan/sort.rs | 2 +-
7 files changed, 186 insertions(+), 163 deletions(-)
diff --git a/rust/arrow/benches/take_kernels.rs
b/rust/arrow/benches/take_kernels.rs
index 1df68dd..3c5d254 100644
--- a/rust/arrow/benches/take_kernels.rs
+++ b/rust/arrow/benches/take_kernels.rs
@@ -22,8 +22,6 @@ use criterion::Criterion;
use rand::distributions::{Alphanumeric, Distribution, Standard};
use rand::Rng;
-use std::sync::Arc;
-
extern crate arrow;
use arrow::array::*;
@@ -32,36 +30,32 @@ use arrow::datatypes::*;
use arrow::util::test_util::seedable_rng;
// cast array from specified primitive array type to desired data type
-fn create_primitive<T>(size: usize) -> ArrayRef
+fn create_primitive<T>(size: usize) -> PrimitiveArray<T>
where
T: ArrowPrimitiveType,
Standard: Distribution<T::Native>,
PrimitiveArray<T>: std::convert::From<Vec<T::Native>>,
{
- let array: PrimitiveArray<T> = seedable_rng()
+ seedable_rng()
.sample_iter(&Standard)
.take(size)
.map(Some)
- .collect();
-
- Arc::new(array) as ArrayRef
+ .collect()
}
// cast array from specified primitive array type to desired data type
-fn create_boolean(size: usize) -> ArrayRef
+fn create_boolean(size: usize) -> BooleanArray
where
Standard: Distribution<bool>,
{
- let array: BooleanArray = seedable_rng()
+ seedable_rng()
.sample_iter(&Standard)
.take(size)
.map(Some)
- .collect();
-
- Arc::new(array) as ArrayRef
+ .collect()
}
-fn create_strings(size: usize, null_density: f32) -> ArrayRef {
+fn create_strings(size: usize, null_density: f32) -> StringArray {
let rng = &mut seedable_rng();
let mut builder = StringBuilder::new(size);
@@ -74,7 +68,7 @@ fn create_strings(size: usize, null_density: f32) -> ArrayRef
{
builder.append_null().unwrap()
}
}
- Arc::new(builder.finish())
+ builder.finish()
}
fn create_random_index(size: usize, null_density: f32) -> UInt32Array {
@@ -91,8 +85,8 @@ fn create_random_index(size: usize, null_density: f32) ->
UInt32Array {
builder.finish()
}
-fn bench_take(values: &ArrayRef, indices: &UInt32Array) {
- criterion::black_box(take(&values, &indices, None).unwrap());
+fn bench_take(values: &dyn Array, indices: &UInt32Array) {
+ criterion::black_box(take(values, &indices, None).unwrap());
}
fn add_benchmark(c: &mut Criterion) {
diff --git a/rust/arrow/src/array/equal/mod.rs
b/rust/arrow/src/array/equal/mod.rs
index f66c0a7..3574a80 100644
--- a/rust/arrow/src/array/equal/mod.rs
+++ b/rust/arrow/src/array/equal/mod.rs
@@ -21,8 +21,9 @@
use super::{
Array, ArrayData, BinaryOffsetSizeTrait, BooleanArray, DecimalArray,
- FixedSizeBinaryArray, GenericBinaryArray, GenericListArray,
GenericStringArray,
- NullArray, OffsetSizeTrait, PrimitiveArray, StringOffsetSizeTrait,
StructArray,
+ FixedSizeBinaryArray, FixedSizeListArray, GenericBinaryArray,
GenericListArray,
+ GenericStringArray, NullArray, OffsetSizeTrait, PrimitiveArray,
+ StringOffsetSizeTrait, StructArray,
};
use crate::{
@@ -116,6 +117,12 @@ impl<OffsetSize: OffsetSizeTrait> PartialEq for
GenericListArray<OffsetSize> {
}
}
+impl PartialEq for FixedSizeListArray {
+ fn eq(&self, other: &Self) -> bool {
+ equal(self.data().as_ref(), other.data().as_ref())
+ }
+}
+
impl PartialEq for StructArray {
fn eq(&self, other: &Self) -> bool {
equal(self.data().as_ref(), other.data().as_ref())
diff --git a/rust/arrow/src/compute/kernels/cast.rs
b/rust/arrow/src/compute/kernels/cast.rs
index c2d4aa2..1dfcc1b 100644
--- a/rust/arrow/src/compute/kernels/cast.rs
+++ b/rust/arrow/src/compute/kernels/cast.rs
@@ -1123,7 +1123,7 @@ where
)
})?;
- take(&cast_dict_values, u32_indicies, None)
+ take(cast_dict_values.as_ref(), u32_indicies, None)
}
/// Attempts to encode an array into an `ArrayDictionary` with index
diff --git a/rust/arrow/src/compute/kernels/sort.rs
b/rust/arrow/src/compute/kernels/sort.rs
index ad4828b..f42d1a8 100644
--- a/rust/arrow/src/compute/kernels/sort.rs
+++ b/rust/arrow/src/compute/kernels/sort.rs
@@ -38,7 +38,7 @@ use TimeUnit::*;
///
pub fn sort(values: &ArrayRef, options: Option<SortOptions>) ->
Result<ArrayRef> {
let indices = sort_to_indices(values, options)?;
- take(values, &indices, None)
+ take(values.as_ref(), &indices, None)
}
// partition indices into non-NaN and NaN
@@ -626,7 +626,7 @@ pub fn lexsort(columns: &[SortColumn]) ->
Result<Vec<ArrayRef>> {
let indices = lexsort_to_indices(columns)?;
columns
.iter()
- .map(|c| take(&c.values, &indices, None))
+ .map(|c| take(c.values.as_ref(), &indices, None))
.collect()
}
diff --git a/rust/arrow/src/compute/kernels/take.rs
b/rust/arrow/src/compute/kernels/take.rs
index e0aebf8..c1d31f9 100644
--- a/rust/arrow/src/compute/kernels/take.rs
+++ b/rust/arrow/src/compute/kernels/take.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-//! Defines take kernel for `ArrayRef`
+//! Defines take kernel for [Array]
use std::{ops::AddAssign, sync::Arc};
@@ -31,31 +31,53 @@ use crate::{array::*, buffer::buffer_bin_and};
use num::{ToPrimitive, Zero};
use TimeUnit::*;
-/// Take elements from `ArrayRef` by copying the data from `values` at
-/// each index in `indices` into a new `ArrayRef`
+macro_rules! downcast_take {
+ ($type: ty, $values: expr, $indices: expr) => {{
+ let values = $values
+ .as_any()
+ .downcast_ref::<PrimitiveArray<$type>>()
+ .expect("Unable to downcast to a primitive array");
+ Ok(Arc::new(take_primitive::<$type, _>(&values, $indices)?))
+ }};
+}
+
+macro_rules! downcast_dict_take {
+ ($type: ty, $values: expr, $indices: expr) => {{
+ let values = $values
+ .as_any()
+ .downcast_ref::<DictionaryArray<$type>>()
+ .expect("Unable to downcast to a dictionary array");
+ Ok(Arc::new(take_dict::<$type, _>(values, $indices)?))
+ }};
+}
+
+/// Take elements by index from [Array], creating a new [Array] from those
indexes.
///
-/// For example:
+/// # Errors
+/// This function errors whenever:
+/// * An index cannot be casted to `usize` (typically 32 bit architectures)
+/// * An index is out of bounds and `options` is set to check bounds.
+/// # Safety
+/// When `options` is not set to check bounds (default), taking indexes after
`len` is undefined behavior.
+/// # Examples
/// ```
-/// use std::sync::Arc;
-/// use arrow::array::{Array, StringArray, UInt32Array};
+/// use arrow::array::{StringArray, UInt32Array};
+/// use arrow::error::Result;
/// use arrow::compute::take;
-///
+/// # fn main() -> Result<()> {
/// let values = StringArray::from(vec!["zero", "one", "two"]);
-/// let values: Arc<dyn Array> = Arc::new(values);
///
/// // Take items at index 2, and 1:
/// let indices = UInt32Array::from(vec![2, 1]);
-/// let taken = take(&values, &indices, None).unwrap();
+/// let taken = take(&values, &indices, None)?;
/// let taken = taken.as_any().downcast_ref::<StringArray>().unwrap();
///
/// assert_eq!(*taken, StringArray::from(vec!["two", "one"]));
+/// # Ok(())
+/// # }
/// ```
-///
-/// Supports:
-/// * null indices, returning a null value for the index
-/// * checking for overflowing indices
pub fn take(
- values: &ArrayRef,
+ values: &Array,
indices: &UInt32Array,
options: Option<TakeOptions>,
) -> Result<ArrayRef> {
@@ -63,7 +85,7 @@ pub fn take(
}
fn take_impl<IndexType>(
- values: &ArrayRef,
+ values: &Array,
indices: &PrimitiveArray<IndexType>,
options: Option<TakeOptions>,
) -> Result<ArrayRef>
@@ -88,67 +110,100 @@ where
}
}
match values.data_type() {
- DataType::Boolean => take_boolean(values, indices),
- DataType::Int8 => take_primitive::<Int8Type, _>(values, indices),
- DataType::Int16 => take_primitive::<Int16Type, _>(values, indices),
- DataType::Int32 => take_primitive::<Int32Type, _>(values, indices),
- DataType::Int64 => take_primitive::<Int64Type, _>(values, indices),
- DataType::UInt8 => take_primitive::<UInt8Type, _>(values, indices),
- DataType::UInt16 => take_primitive::<UInt16Type, _>(values, indices),
- DataType::UInt32 => take_primitive::<UInt32Type, _>(values, indices),
- DataType::UInt64 => take_primitive::<UInt64Type, _>(values, indices),
- DataType::Float32 => take_primitive::<Float32Type, _>(values, indices),
- DataType::Float64 => take_primitive::<Float64Type, _>(values, indices),
- DataType::Date32(_) => take_primitive::<Date32Type, _>(values,
indices),
- DataType::Date64(_) => take_primitive::<Date64Type, _>(values,
indices),
- DataType::Time32(Second) => {
- take_primitive::<Time32SecondType, _>(values, indices)
+ DataType::Boolean => {
+ let values =
values.as_any().downcast_ref::<BooleanArray>().unwrap();
+ Ok(Arc::new(take_boolean(values, indices)?))
}
+ DataType::Int8 => downcast_take!(Int8Type, values, indices),
+ DataType::Int16 => downcast_take!(Int16Type, values, indices),
+ DataType::Int32 => downcast_take!(Int32Type, values, indices),
+ DataType::Int64 => downcast_take!(Int64Type, values, indices),
+ DataType::UInt8 => downcast_take!(UInt8Type, values, indices),
+ DataType::UInt16 => downcast_take!(UInt16Type, values, indices),
+ DataType::UInt32 => downcast_take!(UInt32Type, values, indices),
+ DataType::UInt64 => downcast_take!(UInt64Type, values, indices),
+ DataType::Float32 => downcast_take!(Float32Type, values, indices),
+ DataType::Float64 => downcast_take!(Float64Type, values, indices),
+ DataType::Date32(_) => downcast_take!(Date32Type, values, indices),
+ DataType::Date64(_) => downcast_take!(Date64Type, values, indices),
+ DataType::Time32(Second) => downcast_take!(Time32SecondType, values,
indices),
DataType::Time32(Millisecond) => {
- take_primitive::<Time32MillisecondType, _>(values, indices)
+ downcast_take!(Time32MillisecondType, values, indices)
}
DataType::Time64(Microsecond) => {
- take_primitive::<Time64MicrosecondType, _>(values, indices)
+ downcast_take!(Time64MicrosecondType, values, indices)
}
DataType::Time64(Nanosecond) => {
- take_primitive::<Time64NanosecondType, _>(values, indices)
+ downcast_take!(Time64NanosecondType, values, indices)
}
DataType::Timestamp(Second, _) => {
- take_primitive::<TimestampSecondType, _>(values, indices)
+ downcast_take!(TimestampSecondType, values, indices)
}
DataType::Timestamp(Millisecond, _) => {
- take_primitive::<TimestampMillisecondType, _>(values, indices)
+ downcast_take!(TimestampMillisecondType, values, indices)
}
DataType::Timestamp(Microsecond, _) => {
- take_primitive::<TimestampMicrosecondType, _>(values, indices)
+ downcast_take!(TimestampMicrosecondType, values, indices)
}
DataType::Timestamp(Nanosecond, _) => {
- take_primitive::<TimestampNanosecondType, _>(values, indices)
+ downcast_take!(TimestampNanosecondType, values, indices)
}
DataType::Interval(IntervalUnit::YearMonth) => {
- take_primitive::<IntervalYearMonthType, _>(values, indices)
+ downcast_take!(IntervalYearMonthType, values, indices)
}
DataType::Interval(IntervalUnit::DayTime) => {
- take_primitive::<IntervalDayTimeType, _>(values, indices)
+ downcast_take!(IntervalDayTimeType, values, indices)
}
DataType::Duration(TimeUnit::Second) => {
- take_primitive::<DurationSecondType, _>(values, indices)
+ downcast_take!(DurationSecondType, values, indices)
}
DataType::Duration(TimeUnit::Millisecond) => {
- take_primitive::<DurationMillisecondType, _>(values, indices)
+ downcast_take!(DurationMillisecondType, values, indices)
}
DataType::Duration(TimeUnit::Microsecond) => {
- take_primitive::<DurationMicrosecondType, _>(values, indices)
+ downcast_take!(DurationMicrosecondType, values, indices)
}
DataType::Duration(TimeUnit::Nanosecond) => {
- take_primitive::<DurationNanosecondType, _>(values, indices)
+ downcast_take!(DurationNanosecondType, values, indices)
+ }
+ DataType::Utf8 => {
+ let values = values
+ .as_any()
+ .downcast_ref::<GenericStringArray<i32>>()
+ .unwrap();
+ Ok(Arc::new(take_string::<i32, _>(values, indices)?))
+ }
+ DataType::LargeUtf8 => {
+ let values = values
+ .as_any()
+ .downcast_ref::<GenericStringArray<i64>>()
+ .unwrap();
+ Ok(Arc::new(take_string::<i64, _>(values, indices)?))
+ }
+ DataType::List(_) => {
+ let values = values
+ .as_any()
+ .downcast_ref::<GenericListArray<i32>>()
+ .unwrap();
+ Ok(Arc::new(take_list::<_, Int32Type>(values, indices)?))
+ }
+ DataType::LargeList(_) => {
+ let values = values
+ .as_any()
+ .downcast_ref::<GenericListArray<i64>>()
+ .unwrap();
+ Ok(Arc::new(take_list::<_, Int64Type>(values, indices)?))
}
- DataType::Utf8 => take_string::<i32, _>(values, indices),
- DataType::LargeUtf8 => take_string::<i64, _>(values, indices),
- DataType::List(_) => take_list::<_, Int32Type>(values, indices),
- DataType::LargeList(_) => take_list::<_, Int64Type>(values, indices),
DataType::FixedSizeList(_, length) => {
- take_fixed_size_list(values, indices, *length as u32)
+ let values = values
+ .as_any()
+ .downcast_ref::<FixedSizeListArray>()
+ .unwrap();
+ Ok(Arc::new(take_fixed_size_list(
+ values,
+ indices,
+ *length as u32,
+ )?))
}
DataType::Struct(fields) => {
let struct_: &StructArray =
@@ -156,7 +211,7 @@ where
let arrays: Result<Vec<ArrayRef>> = struct_
.columns()
.iter()
- .map(|a| take_impl(a, indices, Some(options.clone())))
+ .map(|a| take_impl(a.as_ref(), indices, Some(options.clone())))
.collect();
let arrays = arrays?;
let pairs: Vec<(Field, ArrayRef)> =
@@ -164,14 +219,14 @@ where
Ok(Arc::new(StructArray::from(pairs)) as ArrayRef)
}
DataType::Dictionary(key_type, _) => match key_type.as_ref() {
- DataType::Int8 => take_dict::<Int8Type, _>(values, indices),
- DataType::Int16 => take_dict::<Int16Type, _>(values, indices),
- DataType::Int32 => take_dict::<Int32Type, _>(values, indices),
- DataType::Int64 => take_dict::<Int64Type, _>(values, indices),
- DataType::UInt8 => take_dict::<UInt8Type, _>(values, indices),
- DataType::UInt16 => take_dict::<UInt16Type, _>(values, indices),
- DataType::UInt32 => take_dict::<UInt32Type, _>(values, indices),
- DataType::UInt64 => take_dict::<UInt64Type, _>(values, indices),
+ DataType::Int8 => downcast_dict_take!(Int8Type, values, indices),
+ DataType::Int16 => downcast_dict_take!(Int16Type, values, indices),
+ DataType::Int32 => downcast_dict_take!(Int32Type, values, indices),
+ DataType::Int64 => downcast_dict_take!(Int64Type, values, indices),
+ DataType::UInt8 => downcast_dict_take!(UInt8Type, values, indices),
+ DataType::UInt16 => downcast_dict_take!(UInt16Type, values,
indices),
+ DataType::UInt32 => downcast_dict_take!(UInt32Type, values,
indices),
+ DataType::UInt64 => downcast_dict_take!(UInt64Type, values,
indices),
t => unimplemented!("Take not supported for dictionary key type
{:?}", t),
},
t => unimplemented!("Take not supported for data type {:?}", t),
@@ -195,7 +250,7 @@ impl Default for TakeOptions {
}
}
-/// `take` implementation for all primitive arrays except boolean
+/// `take` implementation for all primitive arrays
///
/// This checks if an `indices` slot is populated, and gets the value from
`values`
/// as the populated index.
@@ -205,9 +260,9 @@ impl Default for TakeOptions {
/// indices: [0, null, 4, 3]
/// The result is: [1 (slot 0), null (null slot), 5 (slot 4), null (slot 3)]
fn take_primitive<T, I>(
- values: &ArrayRef,
+ values: &PrimitiveArray<T>,
indices: &PrimitiveArray<I>,
-) -> Result<ArrayRef>
+) -> Result<PrimitiveArray<T>>
where
T: ArrowPrimitiveType,
T::Native: num::Num,
@@ -216,24 +271,20 @@ where
{
let data_len = indices.len();
- let array = values.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
-
- let null_count = array.null_count();
-
let mut buffer = MutableBuffer::new(data_len *
std::mem::size_of::<T::Native>());
buffer.resize(data_len * std::mem::size_of::<T::Native>());
let data = buffer.typed_data_mut();
let nulls;
- if null_count == 0 {
+ if values.null_count() == 0 {
// Take indices without null checking
for (i, elem) in data.iter_mut().enumerate() {
let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(||
{
ArrowError::ComputeError("Cast to usize failed".to_string())
})?;
- *elem = array.value(index);
+ *elem = values.value(index);
}
nulls = indices.data_ref().null_buffer().cloned();
} else {
@@ -247,11 +298,11 @@ where
ArrowError::ComputeError("Cast to usize failed".to_string())
})?;
- if array.is_null(index) {
+ if values.is_null(index) {
bit_util::unset_bit(null_slice, i);
}
- *elem = array.value(index);
+ *elem = values.value(index);
}
nulls = match indices.data_ref().null_buffer() {
Some(buffer) => Some(buffer_bin_and(
@@ -274,28 +325,26 @@ where
vec![buffer.freeze()],
vec![],
);
- Ok(Arc::new(PrimitiveArray::<T>::from(Arc::new(data))))
+ Ok(PrimitiveArray::<T>::from(Arc::new(data)))
}
/// `take` implementation for boolean arrays
fn take_boolean<IndexType>(
- values: &ArrayRef,
+ values: &BooleanArray,
indices: &PrimitiveArray<IndexType>,
-) -> Result<ArrayRef>
+) -> Result<BooleanArray>
where
IndexType: ArrowNumericType,
IndexType::Native: ToPrimitive,
{
let data_len = indices.len();
- let array = values.as_any().downcast_ref::<BooleanArray>().unwrap();
-
let num_byte = bit_util::ceil(data_len, 8);
let mut val_buf = MutableBuffer::new(num_byte).with_bitset(num_byte,
false);
let val_slice = val_buf.data_mut();
- let null_count = array.null_count();
+ let null_count = values.null_count();
let nulls;
if null_count == 0 {
@@ -304,7 +353,7 @@ where
ArrowError::ComputeError("Cast to usize failed".to_string())
})?;
- if array.value(index) {
+ if values.value(index) {
bit_util::set_bit(val_slice, i);
}
@@ -321,9 +370,9 @@ where
ArrowError::ComputeError("Cast to usize failed".to_string())
})?;
- if array.is_null(index) {
+ if values.is_null(index) {
bit_util::unset_bit(null_slice, i);
- } else if array.value(index) {
+ } else if values.value(index) {
bit_util::set_bit(val_slice, i);
}
@@ -351,14 +400,14 @@ where
vec![val_buf.freeze()],
vec![],
);
- Ok(Arc::new(BooleanArray::from(Arc::new(data))))
+ Ok(BooleanArray::from(Arc::new(data)))
}
/// `take` implementation for string arrays
fn take_string<OffsetSize, IndexType>(
- values: &ArrayRef,
+ array: &GenericStringArray<OffsetSize>,
indices: &PrimitiveArray<IndexType>,
-) -> Result<ArrayRef>
+) -> Result<GenericStringArray<OffsetSize>>
where
OffsetSize: Zero + AddAssign + StringOffsetSizeTrait,
IndexType: ArrowNumericType,
@@ -366,11 +415,6 @@ where
{
let data_len = indices.len();
- let array = values
- .as_any()
- .downcast_ref::<GenericStringArray<OffsetSize>>()
- .unwrap();
-
let bytes_offset = (data_len + 1) * std::mem::size_of::<OffsetSize>();
let mut offsets_buffer = MutableBuffer::new(bytes_offset);
offsets_buffer.resize(bytes_offset);
@@ -470,9 +514,7 @@ where
if let Some(null_buffer) = nulls {
data = data.null_bit_buffer(null_buffer);
}
- Ok(Arc::new(GenericStringArray::<OffsetSize>::from(
- data.build(),
- )))
+ Ok(GenericStringArray::<OffsetSize>::from(data.build()))
}
/// `take` implementation for list arrays
@@ -481,9 +523,9 @@ where
/// applying `take` on the inner array, then reconstructing a list array
/// with the indexed offsets
fn take_list<IndexType, OffsetType>(
- values: &ArrayRef,
+ values: &GenericListArray<OffsetType::Native>,
indices: &PrimitiveArray<IndexType>,
-) -> Result<ArrayRef>
+) -> Result<GenericListArray<OffsetType::Native>>
where
IndexType: ArrowNumericType,
IndexType::Native: ToPrimitive,
@@ -493,15 +535,10 @@ where
{
// TODO: Some optimizations can be done here such as if it is
// taking the whole list or a contiguous sublist
- let list = values
- .as_any()
- .downcast_ref::<GenericListArray<OffsetType::Native>>()
- .unwrap();
-
let (list_indices, offsets) =
- take_value_indices_from_list::<IndexType, OffsetType>(list, indices)?;
+ take_value_indices_from_list::<IndexType, OffsetType>(values,
indices)?;
- let taken = take_impl::<OffsetType>(&list.values(), &list_indices, None)?;
+ let taken = take_impl::<OffsetType>(values.values().as_ref(),
&list_indices, None)?;
// determine null count and null buffer, which are a function of `values`
and `indices`
let mut null_count = 0;
let num_bytes = bit_util::ceil(indices.len(), 8);
@@ -520,7 +557,7 @@ where
}
let value_offsets = Buffer::from(offsets[..].to_byte_slice());
// create a new list with taken data and computed null information
- let list_data = ArrayDataBuilder::new(list.data_type().clone())
+ let list_data = ArrayDataBuilder::new(values.data_type().clone())
.len(indices.len())
.null_count(null_count)
.null_bit_buffer(null_buf.freeze())
@@ -528,9 +565,7 @@ where
.add_child_data(taken.data())
.add_buffer(value_offsets)
.build();
- let list_array =
- Arc::new(GenericListArray::<OffsetType::Native>::from(list_data)) as
ArrayRef;
- Ok(list_array)
+ Ok(GenericListArray::<OffsetType::Native>::from(list_data))
}
/// `take` implementation for `FixedSizeListArray`
@@ -539,21 +574,16 @@ where
/// applying `take` on the inner array, then reconstructing a list array
/// with the indexed offsets
fn take_fixed_size_list<IndexType>(
- values: &ArrayRef,
+ values: &FixedSizeListArray,
indices: &PrimitiveArray<IndexType>,
length: <UInt32Type as ArrowPrimitiveType>::Native,
-) -> Result<ArrayRef>
+) -> Result<FixedSizeListArray>
where
IndexType: ArrowNumericType,
IndexType::Native: ToPrimitive,
{
- let list = values
- .as_any()
- .downcast_ref::<FixedSizeListArray>()
- .unwrap();
-
- let list_indices = take_value_indices_from_fixed_size_list(list, indices,
length)?;
- let taken = take_impl::<UInt32Type>(&list.values(), &list_indices, None)?;
+ let list_indices = take_value_indices_from_fixed_size_list(values,
indices, length)?;
+ let taken = take_impl::<UInt32Type>(values.values().as_ref(),
&list_indices, None)?;
// determine null count and null buffer, which are a function of `values`
and `indices`
let mut null_count = 0;
@@ -565,13 +595,13 @@ where
let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| {
ArrowError::ComputeError("Cast to usize failed".to_string())
})?;
- if !indices.is_valid(i) || list.is_null(index) {
+ if !indices.is_valid(i) || values.is_null(index) {
bit_util::unset_bit(null_slice, i);
null_count += 1;
}
}
- let list_data = ArrayDataBuilder::new(list.data_type().clone())
+ let list_data = ArrayDataBuilder::new(values.data_type().clone())
.len(indices.len())
.null_count(null_count)
.null_bit_buffer(null_buf.freeze())
@@ -579,39 +609,37 @@ where
.add_child_data(taken.data())
.build();
- Ok(Arc::new(FixedSizeListArray::from(list_data)))
+ Ok(FixedSizeListArray::from(list_data))
}
/// `take` implementation for dictionary arrays
///
/// applies `take` to the keys of the dictionary array and returns a new
dictionary array
/// with the same dictionary values and reordered keys
-fn take_dict<T, I>(values: &ArrayRef, indices: &PrimitiveArray<I>) ->
Result<ArrayRef>
+fn take_dict<T, I>(
+ values: &DictionaryArray<T>,
+ indices: &PrimitiveArray<I>,
+) -> Result<DictionaryArray<T>>
where
T: ArrowPrimitiveType,
T::Native: num::Num,
I: ArrowNumericType,
I::Native: ToPrimitive,
{
- let dict = values
- .as_any()
- .downcast_ref::<DictionaryArray<T>>()
- .unwrap();
- let keys: ArrayRef = Arc::new(dict.keys_array());
- let new_keys = take_primitive::<T, I>(&keys, indices)?;
+ let new_keys = take_primitive::<T, I>(&values.keys_array(), indices)?;
let new_keys_data = new_keys.data_ref();
let data = Arc::new(ArrayData::new(
- dict.data_type().clone(),
+ values.data_type().clone(),
new_keys.len(),
Some(new_keys_data.null_count()),
new_keys_data.null_buffer().cloned(),
0,
new_keys_data.buffers().to_vec(),
- dict.data().child_data().to_vec(),
+ values.data().child_data().to_vec(),
));
- Ok(Arc::new(DictionaryArray::<T>::from(data)))
+ Ok(DictionaryArray::<T>::from(data))
}
#[cfg(test)]
@@ -627,7 +655,7 @@ mod tests {
) {
let output = BooleanArray::from(data);
let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
- let output = take(&(Arc::new(output) as ArrayRef), index,
options).unwrap();
+ let output = take(&output, index, options).unwrap();
assert_eq!(&output, &expected)
}
@@ -642,7 +670,7 @@ mod tests {
{
let output = PrimitiveArray::<T>::from(data);
let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as
ArrayRef;
- let output = take(&(Arc::new(output) as ArrayRef), index,
options).unwrap();
+ let output = take(&output, index, options).unwrap();
assert_eq!(&output, &expected)
}
@@ -659,13 +687,13 @@ mod tests {
{
let output = PrimitiveArray::<T>::from(data);
let expected = PrimitiveArray::<T>::from(expected_data);
- let output = take_impl(&(Arc::new(output) as ArrayRef), index,
options).unwrap();
+ let output = take_impl(&output, index, options).unwrap();
let output =
output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
assert_eq!(output, &expected)
}
// create a simple struct for testing purposes
- fn create_test_struct() -> ArrayRef {
+ fn create_test_struct() -> StructArray {
let boolean_data = BooleanArray::from(vec![true, false, false,
true]).data();
let int_data = Int32Array::from(vec![42, 28, 19, 31]).data();
let mut field_types = vec![];
@@ -677,8 +705,7 @@ mod tests {
.add_child_data(boolean_data)
.add_child_data(int_data)
.build();
- let struct_array = StructArray::from(struct_array_data);
- Arc::new(struct_array) as ArrayRef
+ StructArray::from(struct_array_data)
}
#[test]
@@ -913,8 +940,6 @@ mod tests {
Some("four"),
Some("five"),
]);
- let array = Arc::new(array) as ArrayRef;
-
let actual = take(&array, &index, None).unwrap();
assert_eq!(actual.len(), index.len());
@@ -954,7 +979,7 @@ mod tests {
.add_buffer(value_offsets)
.add_child_data(value_data)
.build();
- let list_array = Arc::new($list_array_type::from(list_data)) as
ArrayRef;
+ let list_array = $list_array_type::from(list_data);
// index returns: [[2,3], null, [-1,-2,-1], [2,3], [0,0,0]]
let index = UInt32Array::from(vec![Some(2), None, Some(1),
Some(2), Some(0)]);
@@ -1029,7 +1054,7 @@ mod tests {
.null_bit_buffer(Buffer::from([0b10111101, 0b00000000]))
.add_child_data(value_data)
.build();
- let list_array = Arc::new($list_array_type::from(list_data)) as
ArrayRef;
+ let list_array = $list_array_type::from(list_data);
// index returns: [[null], null, [-1,-2,3], [2,null], [0,null,0]]
let index = UInt32Array::from(vec![Some(2), None, Some(1),
Some(3), Some(0)]);
@@ -1102,7 +1127,7 @@ mod tests {
.null_bit_buffer(Buffer::from([0b01111101]))
.add_child_data(value_data)
.build();
- let list_array = Arc::new($list_array_type::from(list_data)) as
ArrayRef;
+ let list_array = $list_array_type::from(list_data);
// index returns: [null, null, [-1,-2,3], [5,null], [0,null,0]]
let index = UInt32Array::from(vec![Some(2), None, Some(1),
Some(3), Some(0)]);
@@ -1157,13 +1182,11 @@ mod tests {
{
let indices = UInt32Array::from(indices);
- let input_array: ArrayRef =
- Arc::new(build_fixed_size_list_nullable::<T>(input_data, length));
+ let input_array = build_fixed_size_list_nullable::<T>(input_data,
length);
let output = take_fixed_size_list(&input_array, &indices, length as
u32).unwrap();
- let expected: ArrayRef =
- Arc::new(build_fixed_size_list_nullable::<T>(expected_data,
length));
+ let expected = build_fixed_size_list_nullable::<T>(expected_data,
length);
assert_eq!(&output, &expected)
}
@@ -1269,7 +1292,7 @@ mod tests {
.add_buffer(value_offsets)
.add_child_data(value_data)
.build();
- let list_array = Arc::new(ListArray::from(list_data)) as ArrayRef;
+ let list_array = ListArray::from(list_data);
let index = UInt32Array::from(vec![1000]);
@@ -1371,7 +1394,6 @@ mod tests {
let array = dict_builder.finish();
let dict_values = array.values().clone();
let dict_values =
dict_values.as_any().downcast_ref::<StringArray>().unwrap();
- let array: Arc<dyn Array> = Arc::new(array);
let indices = UInt32Array::from(vec![
Some(0), // first "foo"
diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs
b/rust/datafusion/src/physical_plan/hash_aggregate.rs
index 1ce3497..9677246 100644
--- a/rust/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs
@@ -303,7 +303,7 @@ fn group_aggregate_batch(
.map(|array| {
// 2.3
compute::take(
- array,
+ array.as_ref(),
&UInt32Array::from(indices.clone()),
None, // None: no index check
)
diff --git a/rust/datafusion/src/physical_plan/sort.rs
b/rust/datafusion/src/physical_plan/sort.rs
index 67873e9..52c7aaa 100644
--- a/rust/datafusion/src/physical_plan/sort.rs
+++ b/rust/datafusion/src/physical_plan/sort.rs
@@ -167,7 +167,7 @@ fn sort_batches(
.iter()
.map(|column| {
take(
- column,
+ column.as_ref(),
&indices,
// disable bound check overhead since indices are already
generated from
// the same record batch