This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 5eeccab92 Use dyn Array in sort kernels (#3931)
5eeccab92 is described below
commit 5eeccab922c377a18e54cf39ad49a2e4d54ffaf2
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Fri Mar 24 12:38:58 2023 +0000
Use dyn Array in sort kernels (#3931)
---
arrow-ord/src/sort.rs | 51 ++++++++++++++++++++++++---------------------------
1 file changed, 24 insertions(+), 27 deletions(-)
diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs
index 9b17651f9..6e0becc36 100644
--- a/arrow-ord/src/sort.rs
+++ b/arrow-ord/src/sort.rs
@@ -47,22 +47,21 @@ pub use arrow_schema::SortOptions;
/// # Example
/// ```rust
/// # use std::sync::Arc;
-/// # use arrow_array::{Int32Array, ArrayRef};
+/// # use arrow_array::Int32Array;
/// # use arrow_ord::sort::sort;
-/// let array: ArrayRef = Arc::new(Int32Array::from(vec![5, 4, 3, 2, 1]));
+/// let array = Int32Array::from(vec![5, 4, 3, 2, 1]);
/// let sorted_array = sort(&array, None).unwrap();
-/// let sorted_array = sorted_array.as_any().downcast_ref::<Int32Array>().unwrap();
-/// assert_eq!(sorted_array, &Int32Array::from(vec![1, 2, 3, 4, 5]));
+/// assert_eq!(sorted_array.as_ref(), &Int32Array::from(vec![1, 2, 3, 4, 5]));
/// ```
pub fn sort(
- values: &ArrayRef,
+ values: &dyn Array,
options: Option<SortOptions>,
) -> Result<ArrayRef, ArrowError> {
if let DataType::RunEndEncoded(_, _) = values.data_type() {
return sort_run(values, options, None);
}
let indices = sort_to_indices(values, options, None)?;
- take(values.as_ref(), &indices, None)
+ take(values, &indices, None)
}
/// Sort the `ArrayRef` partially.
@@ -77,14 +76,13 @@ pub fn sort(
/// # Example
/// ```rust
/// # use std::sync::Arc;
-/// # use arrow_array::{Int32Array, ArrayRef};
+/// # use arrow_array::Int32Array;
/// # use arrow_ord::sort::{sort_limit, SortOptions};
-/// let array: ArrayRef = Arc::new(Int32Array::from(vec![5, 4, 3, 2, 1]));
+/// let array = Int32Array::from(vec![5, 4, 3, 2, 1]);
///
/// // Find the the top 2 items
/// let sorted_array = sort_limit(&array, None, Some(2)).unwrap();
-/// let sorted_array = sorted_array.as_any().downcast_ref::<Int32Array>().unwrap();
-/// assert_eq!(sorted_array, &Int32Array::from(vec![1, 2]));
+/// assert_eq!(sorted_array.as_ref(), &Int32Array::from(vec![1, 2]));
///
/// // Find the bottom top 2 items
/// let options = Some(SortOptions {
@@ -92,11 +90,10 @@ pub fn sort(
/// ..Default::default()
/// });
/// let sorted_array = sort_limit(&array, options, Some(2)).unwrap();
-/// let sorted_array = sorted_array.as_any().downcast_ref::<Int32Array>().unwrap();
-/// assert_eq!(sorted_array, &Int32Array::from(vec![5, 4]));
+/// assert_eq!(sorted_array.as_ref(), &Int32Array::from(vec![5, 4]));
/// ```
pub fn sort_limit(
- values: &ArrayRef,
+ values: &dyn Array,
options: Option<SortOptions>,
limit: Option<usize>,
) -> Result<ArrayRef, ArrowError> {
@@ -104,7 +101,7 @@ pub fn sort_limit(
return sort_run(values, options, limit);
}
let indices = sort_to_indices(values, options, limit)?;
- take(values.as_ref(), &indices, None)
+ take(values, &indices, None)
}
/// we can only do this if the T is primitive
@@ -128,7 +125,7 @@ where
}
// partition indices into valid and null indices
-fn partition_validity(array: &ArrayRef) -> (Vec<u32>, Vec<u32>) {
+fn partition_validity(array: &dyn Array) -> (Vec<u32>, Vec<u32>) {
match array.null_count() {
// faster path
0 => ((0..(array.len() as u32)).collect(), vec![]),
@@ -143,7 +140,7 @@ fn partition_validity(array: &ArrayRef) -> (Vec<u32>, Vec<u32>) {
/// For floating point arrays any NaN values are considered to be greater than any other non-null value
/// limit is an option for partial_sort
pub fn sort_to_indices(
- values: &ArrayRef,
+ values: &dyn Array,
options: Option<SortOptions>,
limit: Option<usize>,
) -> Result<UInt32Array, ArrowError> {
@@ -407,7 +404,7 @@ pub fn sort_to_indices(
/// and [tri-color sort](https://en.wikipedia.org/wiki/Dutch_national_flag_problem)
/// can be used instead.
fn sort_boolean(
- values: &ArrayRef,
+ values: &dyn Array,
value_indices: Vec<u32>,
mut null_indices: Vec<u32>,
options: &SortOptions,
@@ -489,7 +486,7 @@ fn sort_boolean(
/// Sort primitive values
fn sort_primitive<T, F>(
- values: &ArrayRef,
+ values: &dyn Array,
value_indices: Vec<u32>,
null_indices: Vec<u32>,
cmp: F,
@@ -638,7 +635,7 @@ fn insert_valid_values<T>(result_slice: &mut [u32], offset: usize, valids: &[(u3
// will result in output RunArray { run_ends = [2,4,6,8], values = [1,1,2,2] }
// and not RunArray { run_ends = [4,8], values = [1,2] }
fn sort_run(
- values: &ArrayRef,
+ values: &dyn Array,
options: Option<SortOptions>,
limit: Option<usize>,
) -> Result<ArrayRef, ArrowError> {
@@ -656,7 +653,7 @@ fn sort_run(
}
fn sort_run_downcasted<R: RunEndIndexType>(
- values: &ArrayRef,
+ values: &dyn Array,
options: Option<SortOptions>,
limit: Option<usize>,
) -> Result<ArrayRef, ArrowError> {
@@ -719,7 +716,7 @@ fn sort_run_downcasted<R: RunEndIndexType>(
// logical indices and to get the run array back, the logical indices has to be
// encoded back to run array.
fn sort_run_to_indices<R: RunEndIndexType>(
- values: &ArrayRef,
+ values: &dyn Array,
options: &SortOptions,
limit: Option<usize>,
) -> UInt32Array {
@@ -819,7 +816,7 @@ where
/// Sort strings
fn sort_string<Offset: OffsetSizeTrait>(
- values: &ArrayRef,
+ values: &dyn Array,
value_indices: Vec<u32>,
null_indices: Vec<u32>,
options: &SortOptions,
@@ -905,7 +902,7 @@ where
}
fn sort_list<S, T>(
- values: &ArrayRef,
+ values: &dyn Array,
value_indices: Vec<u32>,
null_indices: Vec<u32>,
options: &SortOptions,
@@ -920,7 +917,7 @@ where
}
fn sort_list_inner<S>(
- values: &ArrayRef,
+ values: &dyn Array,
value_indices: Vec<u32>,
mut null_indices: Vec<u32>,
options: &SortOptions,
@@ -971,7 +968,7 @@ where
}
fn sort_binary<S>(
- values: &ArrayRef,
+ values: &dyn Array,
value_indices: Vec<u32>,
mut null_indices: Vec<u32>,
options: &SortOptions,
@@ -3217,7 +3214,7 @@ mod tests {
fn test_sort_run_inner<F>(sort_fn: F)
where
F: Fn(
- &ArrayRef,
+ &dyn Array,
Option<SortOptions>,
Option<usize>,
) -> Result<ArrayRef, ArrowError>,
@@ -3293,7 +3290,7 @@ mod tests {
sort_fn: &F,
) where
F: Fn(
- &ArrayRef,
+ &dyn Array,
Option<SortOptions>,
Option<usize>,
) -> Result<ArrayRef, ArrowError>,