This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new a81da6c89c Cleanup sort (#4613)
a81da6c89c is described below
commit a81da6c89c68507cfb0b37a057dcecd7ba582d9b
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Fri Aug 4 15:28:18 2023 +0100
Cleanup sort (#4613)
* Cleanup sort
* Add inline
* Further cleanup
* Further sort benchmark fixes
---
arrow-array/src/cast.rs | 17 +
arrow-ord/src/sort.rs | 785 +++++++++----------------------------------
arrow/benches/sort_kernel.rs | 26 +-
3 files changed, 196 insertions(+), 632 deletions(-)
diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs
index bee8823d1f..66b40d5b8e 100644
--- a/arrow-array/src/cast.rs
+++ b/arrow-array/src/cast.rs
@@ -799,6 +799,15 @@ pub trait AsArray: private::Sealed {
self.as_list_opt().expect("list array")
}
+ /// Downcast this to a [`FixedSizeBinaryArray`] returning `None` if not
possible
+ fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray>;
+
+ /// Downcast this to a [`FixedSizeBinaryArray`] panicking if not possible
+ fn as_fixed_size_binary(&self) -> &FixedSizeBinaryArray {
+ self.as_fixed_size_binary_opt()
+ .expect("fixed size binary array")
+ }
+
/// Downcast this to a [`FixedSizeListArray`] returning `None` if not
possible
fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray>;
@@ -848,6 +857,10 @@ impl AsArray for dyn Array + '_ {
self.as_any().downcast_ref()
}
+ fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray> {
+ self.as_any().downcast_ref()
+ }
+
fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray> {
self.as_any().downcast_ref()
}
@@ -885,6 +898,10 @@ impl AsArray for ArrayRef {
self.as_ref().as_list_opt()
}
+ fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray> {
+ self.as_ref().as_fixed_size_binary_opt()
+ }
+
fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray> {
self.as_ref().as_fixed_size_list_opt()
}
diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs
index c3e9e26ec0..648a7d7afc 100644
--- a/arrow-ord/src/sort.rs
+++ b/arrow-ord/src/sort.rs
@@ -23,10 +23,9 @@ use arrow_array::cast::*;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::BooleanBufferBuilder;
-use arrow_buffer::{ArrowNativeType, MutableBuffer, NullBuffer};
-use arrow_data::ArrayData;
+use arrow_buffer::{ArrowNativeType, NullBuffer};
use arrow_data::ArrayDataBuilder;
-use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit};
+use arrow_schema::{ArrowError, DataType};
use arrow_select::take::take;
use std::cmp::Ordering;
use std::sync::Arc;
@@ -181,13 +180,6 @@ where
}
}
-fn cmp<T>(l: T, r: T) -> Ordering
-where
- T: Ord,
-{
- l.cmp(&r)
-}
-
// partition indices into valid and null indices
fn partition_validity(array: &dyn Array) -> (Vec<u32>, Vec<u32>) {
match array.null_count() {
@@ -204,210 +196,33 @@ fn partition_validity(array: &dyn Array) -> (Vec<u32>,
Vec<u32>) {
/// For floating point arrays any NaN values are considered to be greater than
any other non-null value.
/// `limit` is an option for [partial_sort].
pub fn sort_to_indices(
- values: &dyn Array,
+ array: &dyn Array,
options: Option<SortOptions>,
limit: Option<usize>,
) -> Result<UInt32Array, ArrowError> {
let options = options.unwrap_or_default();
- let (v, n) = partition_validity(values);
-
- Ok(match values.data_type() {
- DataType::Decimal128(_, _) => {
- sort_primitive::<Decimal128Type, _>(values, v, n, cmp, &options,
limit)
- }
- DataType::Decimal256(_, _) => {
- sort_primitive::<Decimal256Type, _>(values, v, n, cmp, &options,
limit)
- }
- DataType::Boolean => sort_boolean(values, v, n, &options, limit),
- DataType::Int8 => {
- sort_primitive::<Int8Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Int16 => {
- sort_primitive::<Int16Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Int32 => {
- sort_primitive::<Int32Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Int64 => {
- sort_primitive::<Int64Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::UInt8 => {
- sort_primitive::<UInt8Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::UInt16 => {
- sort_primitive::<UInt16Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::UInt32 => {
- sort_primitive::<UInt32Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::UInt64 => {
- sort_primitive::<UInt64Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Float16 => sort_primitive::<Float16Type, _>(
- values,
- v,
- n,
- |x, y| x.total_cmp(&y),
- &options,
- limit,
- ),
- DataType::Float32 => sort_primitive::<Float32Type, _>(
- values,
- v,
- n,
- |x, y| x.total_cmp(&y),
- &options,
- limit,
- ),
- DataType::Float64 => sort_primitive::<Float64Type, _>(
- values,
- v,
- n,
- |x, y| x.total_cmp(&y),
- &options,
- limit,
- ),
- DataType::Date32 => {
- sort_primitive::<Date32Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Date64 => {
- sort_primitive::<Date64Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Time32(TimeUnit::Second) => {
- sort_primitive::<Time32SecondType, _>(values, v, n, cmp, &options,
limit)
- }
- DataType::Time32(TimeUnit::Millisecond) => {
- sort_primitive::<Time32MillisecondType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Time64(TimeUnit::Microsecond) => {
- sort_primitive::<Time64MicrosecondType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Time64(TimeUnit::Nanosecond) => {
- sort_primitive::<Time64NanosecondType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Timestamp(TimeUnit::Second, _) => {
- sort_primitive::<TimestampSecondType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Timestamp(TimeUnit::Millisecond, _) => {
- sort_primitive::<TimestampMillisecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Timestamp(TimeUnit::Microsecond, _) => {
- sort_primitive::<TimestampMicrosecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Timestamp(TimeUnit::Nanosecond, _) => {
- sort_primitive::<TimestampNanosecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Interval(IntervalUnit::YearMonth) => {
- sort_primitive::<IntervalYearMonthType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- sort_primitive::<IntervalDayTimeType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- sort_primitive::<IntervalMonthDayNanoType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Duration(TimeUnit::Second) => {
- sort_primitive::<DurationSecondType, _>(values, v, n, cmp,
&options, limit)
+ let (v, n) = partition_validity(array);
+
+ Ok(downcast_primitive_array! {
+ array => sort_primitive(array, v, n, options, limit),
+ DataType::Boolean => sort_boolean(array.as_boolean(), v, n, options,
limit),
+ DataType::Utf8 => sort_bytes(array.as_string::<i32>(), v, n, options,
limit),
+ DataType::LargeUtf8 => sort_bytes(array.as_string::<i64>(), v, n,
options, limit),
+ DataType::Binary => sort_bytes(array.as_binary::<i32>(), v, n,
options, limit),
+ DataType::LargeBinary => sort_bytes(array.as_binary::<i64>(), v, n,
options, limit),
+ DataType::FixedSizeBinary(_) =>
sort_fixed_size_binary(array.as_fixed_size_binary(), v, n, options, limit),
+ DataType::List(_) => sort_list(array.as_list::<i32>(), v, n, options,
limit)?,
+ DataType::LargeList(_) => sort_list(array.as_list::<i64>(), v, n,
options, limit)?,
+ DataType::FixedSizeList(_, _) =>
sort_fixed_size_list(array.as_fixed_size_list(), v, n, options, limit)?,
+ DataType::Dictionary(_, _) => downcast_dictionary_array!{
+ array => sort_dictionary(array, v, n, options, limit)?,
+ _ => unreachable!()
}
- DataType::Duration(TimeUnit::Millisecond) => {
- sort_primitive::<DurationMillisecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Duration(TimeUnit::Microsecond) => {
- sort_primitive::<DurationMicrosecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Duration(TimeUnit::Nanosecond) => {
- sort_primitive::<DurationNanosecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Utf8 => sort_string::<i32>(values, v, n, &options, limit),
- DataType::LargeUtf8 => sort_string::<i64>(values, v, n, &options,
limit),
- DataType::List(field) | DataType::FixedSizeList(field, _) => {
- match field.data_type() {
- DataType::Int8 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::Int16 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::Int32 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::Int64 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::UInt8 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::UInt16 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::UInt32 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::UInt64 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::Float16 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::Float32 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::Float64 => sort_list::<i32>(values, v, n, &options,
limit),
- t => {
- return Err(ArrowError::ComputeError(format!(
- "Sort not supported for list type {t:?}"
- )));
- }
- }
- }
- DataType::LargeList(field) => match field.data_type() {
- DataType::Int8 => sort_list::<i64>(values, v, n, &options, limit),
- DataType::Int16 => sort_list::<i64>(values, v, n, &options, limit),
- DataType::Int32 => sort_list::<i64>(values, v, n, &options, limit),
- DataType::Int64 => sort_list::<i64>(values, v, n, &options, limit),
- DataType::UInt8 => sort_list::<i64>(values, v, n, &options, limit),
- DataType::UInt16 => sort_list::<i64>(values, v, n, &options,
limit),
- DataType::UInt32 => sort_list::<i64>(values, v, n, &options,
limit),
- DataType::UInt64 => sort_list::<i64>(values, v, n, &options,
limit),
- DataType::Float16 => sort_list::<i64>(values, v, n, &options,
limit),
- DataType::Float32 => sort_list::<i64>(values, v, n, &options,
limit),
- DataType::Float64 => sort_list::<i64>(values, v, n, &options,
limit),
- t => {
- return Err(ArrowError::ComputeError(format!(
- "Sort not supported for list type {t:?}"
- )));
- }
- },
- DataType::Dictionary(_, _) => {
- let value_null_first = if options.descending {
- // When sorting dictionary in descending order, we take
inverse of of null ordering
- // when sorting the values. Because if `nulls_first` is true,
null must be in front
- // of non-null value. As we take the sorted order of value
array to sort dictionary
- // keys, these null values will be treated as smallest ones
and be sorted to the end
- // of sorted result. So we set `nulls_first` to false when
sorting dictionary value
- // array to make them as largest ones, then null values will
be put at the beginning
- // of sorted dictionary result.
- !options.nulls_first
- } else {
- options.nulls_first
- };
- let value_options = Some(SortOptions {
- descending: false,
- nulls_first: value_null_first,
- });
- downcast_dictionary_array! {
- values => {
- let dict_values = values.values();
- let sorted_value_indices = sort_to_indices(dict_values,
value_options, None)?;
- let rank = sorted_rank(&sorted_value_indices);
- sort_dictionary(values, &rank, v, n, options, limit)
- }
- _ => unreachable!(),
- }
- }
- DataType::Binary | DataType::FixedSizeBinary(_) => {
- sort_binary::<i32>(values, v, n, &options, limit)
- }
- DataType::LargeBinary => sort_binary::<i64>(values, v, n, &options,
limit),
DataType::RunEndEncoded(run_ends_field, _) => match
run_ends_field.data_type() {
- DataType::Int16 => sort_run_to_indices::<Int16Type>(values,
&options, limit),
- DataType::Int32 => sort_run_to_indices::<Int32Type>(values,
&options, limit),
- DataType::Int64 => sort_run_to_indices::<Int64Type>(values,
&options, limit),
+ DataType::Int16 => sort_run_to_indices::<Int16Type>(array,
options, limit),
+ DataType::Int32 => sort_run_to_indices::<Int32Type>(array,
options, limit),
+ DataType::Int64 => sort_run_to_indices::<Int64Type>(array,
options, limit),
dt => {
return Err(ArrowError::ComputeError(format!(
"Invalid run end data type: {dt}"
@@ -422,147 +237,76 @@ pub fn sort_to_indices(
})
}
-/// Sort boolean values
-///
-/// when a limit is present, the sort is pair-comparison based as k-select
might be more efficient,
-/// when the limit is absent, binary partition is used to speed up (which is
linear).
-///
-/// TODO maybe partition_validity call can be eliminated in this case
-/// and [tri-color
sort](https://en.wikipedia.org/wiki/Dutch_national_flag_problem)
-/// can be used instead.
fn sort_boolean(
- values: &dyn Array,
+ values: &BooleanArray,
value_indices: Vec<u32>,
- mut null_indices: Vec<u32>,
- options: &SortOptions,
+ null_indices: Vec<u32>,
+ options: SortOptions,
limit: Option<usize>,
) -> UInt32Array {
- let values = values
- .as_any()
- .downcast_ref::<BooleanArray>()
- .expect("Unable to downcast to boolean array");
- let descending = options.descending;
-
- let valids_len = value_indices.len();
- let nulls_len = null_indices.len();
-
- let mut len = values.len();
- let valids = if let Some(limit) = limit {
- len = limit.min(len);
- // create tuples that are used for sorting
- let mut valids = value_indices
- .into_iter()
- .map(|index| (index, values.value(index as usize)))
- .collect::<Vec<(u32, bool)>>();
-
- sort_valids(descending, &mut valids, len, cmp);
- valids
- } else {
- // when limit is not present, we have a better way than sorting: we
can just partition
- // the vec into [false..., true...] or [true..., false...] when
descending
- // TODO when https://github.com/rust-lang/rust/issues/62543 is merged
we can use partition_in_place
- let (mut a, b): (Vec<_>, Vec<_>) = value_indices
- .into_iter()
- .map(|index| (index, values.value(index as usize)))
- .partition(|(_, value)| *value == descending);
- a.extend(b);
- if descending {
- null_indices.reverse();
- }
- a
- };
-
- let nulls = null_indices;
-
- // collect results directly into a buffer instead of a vec to avoid
another aligned allocation
- let result_capacity = len * std::mem::size_of::<u32>();
- let mut result = MutableBuffer::new(result_capacity);
- // sets len to capacity so we can access the whole buffer as a typed slice
- result.resize(result_capacity, 0);
- let result_slice: &mut [u32] = result.typed_data_mut();
-
- if options.nulls_first {
- let size = nulls_len.min(len);
- result_slice[0..size].copy_from_slice(&nulls[0..size]);
- if nulls_len < len {
- insert_valid_values(result_slice, nulls_len, &valids[0..len -
size]);
- }
- } else {
- // nulls last
- let size = valids.len().min(len);
- insert_valid_values(result_slice, 0, &valids[0..size]);
- if len > size {
- result_slice[valids_len..].copy_from_slice(&nulls[0..(len -
valids_len)]);
- }
- }
-
- let result_data = unsafe {
- ArrayData::new_unchecked(
- DataType::UInt32,
- len,
- Some(0),
- None,
- 0,
- vec![result.into()],
- vec![],
- )
- };
+ let mut valids = value_indices
+ .into_iter()
+ .map(|index| (index, values.value(index as usize)))
+ .collect::<Vec<(u32, bool)>>();
+ sort_impl(options, &mut valids, &null_indices, limit, |a, b|
a.cmp(&b)).into()
+}
- UInt32Array::from(result_data)
+fn sort_primitive<T: ArrowPrimitiveType>(
+ values: &PrimitiveArray<T>,
+ value_indices: Vec<u32>,
+ nulls: Vec<u32>,
+ options: SortOptions,
+ limit: Option<usize>,
+) -> UInt32Array {
+ let mut valids = value_indices
+ .into_iter()
+ .map(|index| (index, values.value(index as usize)))
+ .collect::<Vec<(u32, T::Native)>>();
+ sort_impl(options, &mut valids, &nulls, limit, T::Native::compare).into()
}
-/// Sort primitive values
-fn sort_primitive<T, F>(
- values: &dyn Array,
+fn sort_bytes<T: ByteArrayType>(
+ values: &GenericByteArray<T>,
value_indices: Vec<u32>,
- null_indices: Vec<u32>,
- cmp: F,
- options: &SortOptions,
+ nulls: Vec<u32>,
+ options: SortOptions,
limit: Option<usize>,
-) -> UInt32Array
-where
- T: ArrowPrimitiveType,
- T::Native: PartialOrd,
- F: Fn(T::Native, T::Native) -> Ordering,
-{
- // create tuples that are used for sorting
- let valids = {
- let values = values.as_primitive::<T>();
- value_indices
- .into_iter()
- .map(|index| (index, values.value(index as usize)))
- .collect::<Vec<(u32, T::Native)>>()
- };
- sort_primitive_inner(values.len(), null_indices, cmp, options, limit,
valids)
+) -> UInt32Array {
+ let mut valids = value_indices
+ .into_iter()
+ .map(|index| (index, values.value(index as usize).as_ref()))
+ .collect::<Vec<(u32, &[u8])>>();
+
+ sort_impl(options, &mut valids, &nulls, limit, Ord::cmp).into()
}
-/// Given a list of indices that yield a sorted order, returns the ordered
-/// rank of each index
-///
-/// e.g. [2, 4, 3, 1, 0] -> [4, 3, 0, 2, 1]
-fn sorted_rank(sorted_value_indices: &UInt32Array) -> Vec<u32> {
- assert_eq!(sorted_value_indices.null_count(), 0);
- let sorted_indices = sorted_value_indices.values();
- let mut out: Vec<_> = vec![0_u32; sorted_indices.len()];
- for (ix, val) in sorted_indices.iter().enumerate() {
- out[*val as usize] = ix as u32;
- }
- out
+fn sort_fixed_size_binary(
+ values: &FixedSizeBinaryArray,
+ value_indices: Vec<u32>,
+ nulls: Vec<u32>,
+ options: SortOptions,
+ limit: Option<usize>,
+) -> UInt32Array {
+ let mut valids = value_indices
+ .iter()
+ .copied()
+ .map(|index| (index, values.value(index as usize)))
+ .collect::<Vec<(u32, &[u8])>>();
+ sort_impl(options, &mut valids, &nulls, limit, Ord::cmp).into()
}
-/// Sort dictionary given the sorted rank of each key
fn sort_dictionary<K: ArrowDictionaryKeyType>(
dict: &DictionaryArray<K>,
- rank: &[u32],
value_indices: Vec<u32>,
null_indices: Vec<u32>,
options: SortOptions,
limit: Option<usize>,
-) -> UInt32Array {
+) -> Result<UInt32Array, ArrowError> {
let keys: &PrimitiveArray<K> = dict.keys();
+ let rank = child_rank(dict.values().as_ref(), options)?;
// create tuples that are used for sorting
- let valids = value_indices
+ let mut valids = value_indices
.into_iter()
.map(|index| {
let key: K::Native = keys.value(index as usize);
@@ -570,83 +314,100 @@ fn sort_dictionary<K: ArrowDictionaryKeyType>(
})
.collect::<Vec<(u32, u32)>>();
- sort_primitive_inner::<_, _>(keys.len(), null_indices, cmp, &options,
limit, valids)
+ Ok(sort_impl(options, &mut valids, &null_indices, limit, |a, b|
a.cmp(&b)).into())
}
-// sort is instantiated a lot so we only compile this inner version for each
native type
-fn sort_primitive_inner<T, F>(
- value_len: usize,
- nulls: Vec<u32>,
- cmp: F,
- options: &SortOptions,
+fn sort_list<O: OffsetSizeTrait>(
+ array: &GenericListArray<O>,
+ value_indices: Vec<u32>,
+ null_indices: Vec<u32>,
+ options: SortOptions,
limit: Option<usize>,
- mut valids: Vec<(u32, T)>,
-) -> UInt32Array
-where
- T: ArrowNativeType,
- T: PartialOrd,
- F: Fn(T, T) -> Ordering,
-{
- let valids_len = valids.len();
- let nulls_len = nulls.len();
- let mut len = value_len;
+) -> Result<UInt32Array, ArrowError> {
+ let rank = child_rank(array.values().as_ref(), options)?;
+ let offsets = array.value_offsets();
+ let mut valids = value_indices
+ .into_iter()
+ .map(|index| {
+ let end = offsets[index as usize + 1].as_usize();
+ let start = offsets[index as usize].as_usize();
+ (index, &rank[start..end])
+ })
+ .collect::<Vec<(u32, &[u32])>>();
+ Ok(sort_impl(options, &mut valids, &null_indices, limit, Ord::cmp).into())
+}
- if let Some(limit) = limit {
- len = limit.min(len);
- }
+fn sort_fixed_size_list(
+ array: &FixedSizeListArray,
+ value_indices: Vec<u32>,
+ null_indices: Vec<u32>,
+ options: SortOptions,
+ limit: Option<usize>,
+) -> Result<UInt32Array, ArrowError> {
+ let rank = child_rank(array.values().as_ref(), options)?;
+ let size = array.value_length() as usize;
+ let mut valids = value_indices
+ .into_iter()
+ .map(|index| {
+ let start = index as usize * size;
+ (index, &rank[start..start + size])
+ })
+ .collect::<Vec<(u32, &[u32])>>();
+ Ok(sort_impl(options, &mut valids, &null_indices, limit, Ord::cmp).into())
+}
- sort_valids(options.descending, &mut valids, len, cmp);
+#[inline(never)]
+fn sort_impl<T: ?Sized + Copy>(
+ options: SortOptions,
+ valids: &mut [(u32, T)],
+ nulls: &[u32],
+ limit: Option<usize>,
+ mut cmp: impl FnMut(T, T) -> Ordering,
+) -> Vec<u32> {
+ let v_limit = match (limit, options.nulls_first) {
+ (Some(l), true) => l.saturating_sub(nulls.len()).min(valids.len()),
+ _ => valids.len(),
+ };
- // collect results directly into a buffer instead of a vec to avoid
another aligned allocation
- let result_capacity = len * std::mem::size_of::<u32>();
- let mut result = MutableBuffer::new(result_capacity);
- // sets len to capacity so we can access the whole buffer as a typed slice
- result.resize(result_capacity, 0);
- let result_slice: &mut [u32] = result.typed_data_mut();
+ match options.descending {
+ false => sort_unstable_by(valids, v_limit, |a, b| cmp(a.1, b.1)),
+ true => sort_unstable_by(valids, v_limit, |a, b| cmp(a.1,
b.1).reverse()),
+ }
- if options.nulls_first {
- let size = nulls_len.min(len);
- result_slice[0..size].copy_from_slice(&nulls[0..size]);
- if nulls_len < len {
- insert_valid_values(result_slice, nulls_len, &valids[0..len -
size]);
+ let len = valids.len() + nulls.len();
+ let limit = limit.unwrap_or(len).min(len);
+ let mut out = Vec::with_capacity(len);
+ match options.nulls_first {
+ true => {
+ out.extend_from_slice(&nulls[..nulls.len().min(limit)]);
+ let remaining = limit - out.len();
+ out.extend(valids.iter().map(|x| x.0).take(remaining));
}
- } else {
- // nulls last
- let size = valids.len().min(len);
- insert_valid_values(result_slice, 0, &valids[0..size]);
- if len > size {
- result_slice[valids_len..].copy_from_slice(&nulls[0..(len -
valids_len)]);
+ false => {
+ out.extend(valids.iter().map(|x| x.0).take(limit));
+ let remaining = limit - out.len();
+ out.extend_from_slice(&nulls[..remaining])
}
}
-
- let result_data = unsafe {
- ArrayData::new_unchecked(
- DataType::UInt32,
- len,
- Some(0),
- None,
- 0,
- vec![result.into()],
- vec![],
- )
- };
-
- UInt32Array::from(result_data)
+ out
}
-// insert valid and nan values in the correct order depending on the
descending flag
-fn insert_valid_values<T>(result_slice: &mut [u32], offset: usize, valids:
&[(u32, T)]) {
- let valids_len = valids.len();
- // helper to append the index part of the valid tuples
- let append_valids = move |dst_slice: &mut [u32]| {
- debug_assert_eq!(dst_slice.len(), valids_len);
- dst_slice
- .iter_mut()
- .zip(valids.iter())
- .for_each(|(dst, src)| *dst = src.0)
- };
+/// Computes the rank for a set of child values
+fn child_rank(values: &dyn Array, options: SortOptions) -> Result<Vec<u32>,
ArrowError> {
+ // If parent sort order is descending we need to invert the value of
nulls_first so that
+ // when the parent is sorted based on the produced ranks, nulls are still
ordered correctly
+ let value_options = Some(SortOptions {
+ descending: false,
+ nulls_first: options.nulls_first != options.descending,
+ });
- append_valids(&mut result_slice[offset..offset + valids.len()]);
+ let sorted_value_indices = sort_to_indices(values, value_options, None)?;
+ let sorted_indices = sorted_value_indices.values();
+ let mut out: Vec<_> = vec![0_u32; sorted_indices.len()];
+ for (ix, val) in sorted_indices.iter().enumerate() {
+ out[*val as usize] = ix as u32;
+ }
+ Ok(out)
}
// Sort run array and return sorted run array.
@@ -737,7 +498,7 @@ fn sort_run_downcasted<R: RunEndIndexType>(
// encoded back to run array.
fn sort_run_to_indices<R: RunEndIndexType>(
values: &dyn Array,
- options: &SortOptions,
+ options: SortOptions,
limit: Option<usize>,
) -> UInt32Array {
let run_array = values.as_any().downcast_ref::<RunArray<R>>().unwrap();
@@ -752,7 +513,7 @@ fn sort_run_to_indices<R: RunEndIndexType>(
let consume_runs = |run_length, logical_start| {
result.extend(logical_start as u32..(logical_start + run_length) as
u32);
};
- sort_run_inner(run_array, Some(*options), output_len, consume_runs);
+ sort_run_inner(run_array, Some(options), output_len, consume_runs);
UInt32Array::from(result)
}
@@ -834,200 +595,6 @@ where
(values_indices, run_values)
}
-/// Sort strings
-fn sort_string<Offset: OffsetSizeTrait>(
- values: &dyn Array,
- value_indices: Vec<u32>,
- null_indices: Vec<u32>,
- options: &SortOptions,
- limit: Option<usize>,
-) -> UInt32Array {
- let values = values
- .as_any()
- .downcast_ref::<GenericStringArray<Offset>>()
- .unwrap();
-
- sort_string_helper(
- values,
- value_indices,
- null_indices,
- options,
- limit,
- |array, idx| array.value(idx as usize),
- )
-}
-
-/// shared implementation between dictionary encoded and plain string arrays
-#[inline]
-fn sort_string_helper<'a, A: Array, F>(
- values: &'a A,
- value_indices: Vec<u32>,
- null_indices: Vec<u32>,
- options: &SortOptions,
- limit: Option<usize>,
- value_fn: F,
-) -> UInt32Array
-where
- F: Fn(&'a A, u32) -> &str,
-{
- let mut valids = value_indices
- .into_iter()
- .map(|index| (index, value_fn(values, index)))
- .collect::<Vec<(u32, &str)>>();
- let mut nulls = null_indices;
- let descending = options.descending;
- let mut len = values.len();
-
- if let Some(limit) = limit {
- len = limit.min(len);
- }
-
- sort_valids(descending, &mut valids, len, cmp);
- // collect the order of valid tuplies
- let mut valid_indices: Vec<u32> = valids.iter().map(|tuple|
tuple.0).collect();
-
- if options.nulls_first {
- nulls.append(&mut valid_indices);
- nulls.truncate(len);
- UInt32Array::from(nulls)
- } else {
- // no need to sort nulls as they are in the correct order already
- valid_indices.append(&mut nulls);
- valid_indices.truncate(len);
- UInt32Array::from(valid_indices)
- }
-}
-
-fn sort_list<S>(
- values: &dyn Array,
- value_indices: Vec<u32>,
- null_indices: Vec<u32>,
- options: &SortOptions,
- limit: Option<usize>,
-) -> UInt32Array
-where
- S: OffsetSizeTrait,
-{
- sort_list_inner::<S>(values, value_indices, null_indices, options, limit)
-}
-
-fn sort_list_inner<S>(
- values: &dyn Array,
- value_indices: Vec<u32>,
- mut null_indices: Vec<u32>,
- options: &SortOptions,
- limit: Option<usize>,
-) -> UInt32Array
-where
- S: OffsetSizeTrait,
-{
- let mut valids: Vec<(u32, ArrayRef)> = values
- .as_any()
- .downcast_ref::<FixedSizeListArray>()
- .map_or_else(
- || {
- let values = as_generic_list_array::<S>(values);
- value_indices
- .iter()
- .copied()
- .map(|index| (index, values.value(index as usize)))
- .collect()
- },
- |values| {
- value_indices
- .iter()
- .copied()
- .map(|index| (index, values.value(index as usize)))
- .collect()
- },
- );
-
- let mut len = values.len();
- let descending = options.descending;
-
- if let Some(limit) = limit {
- len = limit.min(len);
- }
- sort_valids_array(descending, &mut valids, &mut null_indices, len);
-
- let mut valid_indices: Vec<u32> = valids.iter().map(|tuple|
tuple.0).collect();
- if options.nulls_first {
- null_indices.append(&mut valid_indices);
- null_indices.truncate(len);
- UInt32Array::from(null_indices)
- } else {
- valid_indices.append(&mut null_indices);
- valid_indices.truncate(len);
- UInt32Array::from(valid_indices)
- }
-}
-
-fn sort_binary<S>(
- values: &dyn Array,
- value_indices: Vec<u32>,
- mut null_indices: Vec<u32>,
- options: &SortOptions,
- limit: Option<usize>,
-) -> UInt32Array
-where
- S: OffsetSizeTrait,
-{
- let mut valids: Vec<(u32, &[u8])> = values
- .as_any()
- .downcast_ref::<FixedSizeBinaryArray>()
- .map_or_else(
- || {
- let values = as_generic_binary_array::<S>(values);
- value_indices
- .iter()
- .copied()
- .map(|index| (index, values.value(index as usize)))
- .collect()
- },
- |values| {
- value_indices
- .iter()
- .copied()
- .map(|index| (index, values.value(index as usize)))
- .collect()
- },
- );
-
- let mut len = values.len();
- let descending = options.descending;
-
- if let Some(limit) = limit {
- len = limit.min(len);
- }
-
- sort_valids(descending, &mut valids, len, cmp);
-
- let mut valid_indices: Vec<u32> = valids.iter().map(|tuple|
tuple.0).collect();
- if options.nulls_first {
- null_indices.append(&mut valid_indices);
- null_indices.truncate(len);
- UInt32Array::from(null_indices)
- } else {
- valid_indices.append(&mut null_indices);
- valid_indices.truncate(len);
- UInt32Array::from(valid_indices)
- }
-}
-
-/// Compare two `Array`s based on the ordering defined in [build_compare]
-fn cmp_array(a: &dyn Array, b: &dyn Array) -> Ordering {
- let cmp_op = build_compare(a, b).unwrap();
- let length = a.len().max(b.len());
-
- for i in 0..length {
- let result = cmp_op(i, i);
- if result != Ordering::Equal {
- return result;
- }
- }
- Ordering::Equal
-}
-
/// One column to be used in lexicographical sort
#[derive(Clone, Debug)]
pub struct SortColumn {
@@ -1146,8 +713,10 @@ pub fn partial_sort<T, F>(v: &mut [T], limit: usize, mut
is_less: F)
where
F: FnMut(&T, &T) -> Ordering,
{
- let (before, _mid, _after) = v.select_nth_unstable_by(limit, &mut is_less);
- before.sort_unstable_by(is_less);
+ if let Some(n) = limit.checked_sub(1) {
+ let (before, _mid, _after) = v.select_nth_unstable_by(n, &mut is_less);
+ before.sort_unstable_by(is_less);
+ }
}
type LexicographicalCompareItem<'a> = (
@@ -1228,42 +797,6 @@ impl LexicographicalComparator<'_> {
}
}
-fn sort_valids<T>(
- descending: bool,
- valids: &mut [(u32, T)],
- len: usize,
- mut cmp: impl FnMut(T, T) -> Ordering,
-) where
- T: ?Sized + Copy,
-{
- let valids_len = valids.len();
- if !descending {
- sort_unstable_by(valids, len.min(valids_len), |a, b| cmp(a.1, b.1));
- } else {
- sort_unstable_by(valids, len.min(valids_len), |a, b| cmp(a.1,
b.1).reverse());
- }
-}
-
-fn sort_valids_array<T>(
- descending: bool,
- valids: &mut [(u32, ArrayRef)],
- nulls: &mut [T],
- len: usize,
-) {
- let valids_len = valids.len();
- if !descending {
- sort_unstable_by(valids, len.min(valids_len), |a, b| {
- cmp_array(a.1.as_ref(), b.1.as_ref())
- });
- } else {
- sort_unstable_by(valids, len.min(valids_len), |a, b| {
- cmp_array(a.1.as_ref(), b.1.as_ref()).reverse()
- });
- // reverse to keep a stable ordering
- nulls.reverse();
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;
@@ -1980,7 +1513,7 @@ mod tests {
nulls_first: false,
}),
None,
- vec![2, 3, 1, 4, 5, 0],
+ vec![2, 3, 1, 4, 0, 5],
);
// boolean, descending, nulls first
@@ -1991,7 +1524,7 @@ mod tests {
nulls_first: true,
}),
None,
- vec![5, 0, 2, 3, 1, 4],
+ vec![0, 5, 2, 3, 1, 4],
);
// boolean, descending, nulls first, limit
diff --git a/arrow/benches/sort_kernel.rs b/arrow/benches/sort_kernel.rs
index 3a3ce4462d..63e10e0528 100644
--- a/arrow/benches/sort_kernel.rs
+++ b/arrow/benches/sort_kernel.rs
@@ -67,23 +67,37 @@ fn bench_sort_to_indices(array: &dyn Array, limit:
Option<usize>) {
fn add_benchmark(c: &mut Criterion) {
let arr = create_primitive_array::<Int32Type>(2usize.pow(10), 0.0);
- c.bench_function("sort i64 2^10", |b| b.iter(|| bench_sort(&arr)));
-
- let arr = create_primitive_array::<Int32Type>(2usize.pow(12), 0.5);
- c.bench_function("sort i64 2^12", |b| b.iter(|| bench_sort(&arr)));
+ c.bench_function("sort i32 2^10", |b| b.iter(|| bench_sort(&arr)));
+ c.bench_function("sort i32 to indices 2^10", |b| {
+ b.iter(|| bench_sort_to_indices(&arr, None))
+ });
let arr = create_primitive_array::<Int32Type>(2usize.pow(12), 0.0);
- c.bench_function("sort i64 nulls 2^10", |b| b.iter(|| bench_sort(&arr)));
+ c.bench_function("sort i32 2^12", |b| b.iter(|| bench_sort(&arr)));
+ c.bench_function("sort i32 to indices 2^12", |b| {
+ b.iter(|| bench_sort_to_indices(&arr, None))
+ });
+
+ let arr = create_primitive_array::<Int32Type>(2usize.pow(10), 0.5);
+ c.bench_function("sort i32 nulls 2^10", |b| b.iter(|| bench_sort(&arr)));
+ c.bench_function("sort i32 nulls to indices 2^10", |b| {
+ b.iter(|| bench_sort_to_indices(&arr, None))
+ });
let arr = create_primitive_array::<Int32Type>(2usize.pow(12), 0.5);
- c.bench_function("sort i64 nulls 2^12", |b| b.iter(|| bench_sort(&arr)));
+ c.bench_function("sort i32 nulls 2^12", |b| b.iter(|| bench_sort(&arr)));
+ c.bench_function("sort i32 nulls to indices 2^12", |b| {
+ b.iter(|| bench_sort_to_indices(&arr, None))
+ });
let arr = create_f32_array(2_usize.pow(12), false);
+ c.bench_function("sort f32 2^12", |b| b.iter(|| bench_sort(&arr)));
c.bench_function("sort f32 to indices 2^12", |b| {
b.iter(|| bench_sort_to_indices(&arr, None))
});
let arr = create_f32_array(2usize.pow(12), true);
+ c.bench_function("sort f32 nulls 2^12", |b| b.iter(|| bench_sort(&arr)));
c.bench_function("sort f32 nulls to indices 2^12", |b| {
b.iter(|| bench_sort_to_indices(&arr, None))
});