This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new a81da6c89c Cleanup sort (#4613)
a81da6c89c is described below

commit a81da6c89c68507cfb0b37a057dcecd7ba582d9b
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Fri Aug 4 15:28:18 2023 +0100

    Cleanup sort (#4613)
    
    * Cleanup sort
    
    * Add inline
    
    * Further cleanup
    
    * Further sort benchmark fixes
---
 arrow-array/src/cast.rs      |  17 +
 arrow-ord/src/sort.rs        | 785 +++++++++----------------------------------
 arrow/benches/sort_kernel.rs |  26 +-
 3 files changed, 196 insertions(+), 632 deletions(-)

diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs
index bee8823d1f..66b40d5b8e 100644
--- a/arrow-array/src/cast.rs
+++ b/arrow-array/src/cast.rs
@@ -799,6 +799,15 @@ pub trait AsArray: private::Sealed {
         self.as_list_opt().expect("list array")
     }
 
+    /// Downcast this to a [`FixedSizeBinaryArray`] returning `None` if not possible
+    fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray>;
+
+    /// Downcast this to a [`FixedSizeBinaryArray`] panicking if not possible
+    fn as_fixed_size_binary(&self) -> &FixedSizeBinaryArray {
+        self.as_fixed_size_binary_opt()
+            .expect("fixed size binary array")
+    }
+
     /// Downcast this to a [`FixedSizeListArray`] returning `None` if not possible
     fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray>;
 
@@ -848,6 +857,10 @@ impl AsArray for dyn Array + '_ {
         self.as_any().downcast_ref()
     }
 
+    fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray> {
+        self.as_any().downcast_ref()
+    }
+
     fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray> {
         self.as_any().downcast_ref()
     }
@@ -885,6 +898,10 @@ impl AsArray for ArrayRef {
         self.as_ref().as_list_opt()
     }
 
+    fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray> {
+        self.as_ref().as_fixed_size_binary_opt()
+    }
+
     fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray> {
         self.as_ref().as_fixed_size_list_opt()
     }
diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs
index c3e9e26ec0..648a7d7afc 100644
--- a/arrow-ord/src/sort.rs
+++ b/arrow-ord/src/sort.rs
@@ -23,10 +23,9 @@ use arrow_array::cast::*;
 use arrow_array::types::*;
 use arrow_array::*;
 use arrow_buffer::BooleanBufferBuilder;
-use arrow_buffer::{ArrowNativeType, MutableBuffer, NullBuffer};
-use arrow_data::ArrayData;
+use arrow_buffer::{ArrowNativeType, NullBuffer};
 use arrow_data::ArrayDataBuilder;
-use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit};
+use arrow_schema::{ArrowError, DataType};
 use arrow_select::take::take;
 use std::cmp::Ordering;
 use std::sync::Arc;
@@ -181,13 +180,6 @@ where
     }
 }
 
-fn cmp<T>(l: T, r: T) -> Ordering
-where
-    T: Ord,
-{
-    l.cmp(&r)
-}
-
 // partition indices into valid and null indices
 fn partition_validity(array: &dyn Array) -> (Vec<u32>, Vec<u32>) {
     match array.null_count() {
@@ -204,210 +196,33 @@ fn partition_validity(array: &dyn Array) -> (Vec<u32>, Vec<u32>) {
 /// For floating point arrays any NaN values are considered to be greater than any other non-null value.
 /// `limit` is an option for [partial_sort].
 pub fn sort_to_indices(
-    values: &dyn Array,
+    array: &dyn Array,
     options: Option<SortOptions>,
     limit: Option<usize>,
 ) -> Result<UInt32Array, ArrowError> {
     let options = options.unwrap_or_default();
 
-    let (v, n) = partition_validity(values);
-
-    Ok(match values.data_type() {
-        DataType::Decimal128(_, _) => {
-            sort_primitive::<Decimal128Type, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::Decimal256(_, _) => {
-            sort_primitive::<Decimal256Type, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::Boolean => sort_boolean(values, v, n, &options, limit),
-        DataType::Int8 => {
-            sort_primitive::<Int8Type, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::Int16 => {
-            sort_primitive::<Int16Type, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::Int32 => {
-            sort_primitive::<Int32Type, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::Int64 => {
-            sort_primitive::<Int64Type, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::UInt8 => {
-            sort_primitive::<UInt8Type, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::UInt16 => {
-            sort_primitive::<UInt16Type, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::UInt32 => {
-            sort_primitive::<UInt32Type, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::UInt64 => {
-            sort_primitive::<UInt64Type, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::Float16 => sort_primitive::<Float16Type, _>(
-            values,
-            v,
-            n,
-            |x, y| x.total_cmp(&y),
-            &options,
-            limit,
-        ),
-        DataType::Float32 => sort_primitive::<Float32Type, _>(
-            values,
-            v,
-            n,
-            |x, y| x.total_cmp(&y),
-            &options,
-            limit,
-        ),
-        DataType::Float64 => sort_primitive::<Float64Type, _>(
-            values,
-            v,
-            n,
-            |x, y| x.total_cmp(&y),
-            &options,
-            limit,
-        ),
-        DataType::Date32 => {
-            sort_primitive::<Date32Type, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::Date64 => {
-            sort_primitive::<Date64Type, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::Time32(TimeUnit::Second) => {
-            sort_primitive::<Time32SecondType, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::Time32(TimeUnit::Millisecond) => {
-            sort_primitive::<Time32MillisecondType, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::Time64(TimeUnit::Microsecond) => {
-            sort_primitive::<Time64MicrosecondType, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::Time64(TimeUnit::Nanosecond) => {
-            sort_primitive::<Time64NanosecondType, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::Timestamp(TimeUnit::Second, _) => {
-            sort_primitive::<TimestampSecondType, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::Timestamp(TimeUnit::Millisecond, _) => {
-            sort_primitive::<TimestampMillisecondType, _>(
-                values, v, n, cmp, &options, limit,
-            )
-        }
-        DataType::Timestamp(TimeUnit::Microsecond, _) => {
-            sort_primitive::<TimestampMicrosecondType, _>(
-                values, v, n, cmp, &options, limit,
-            )
-        }
-        DataType::Timestamp(TimeUnit::Nanosecond, _) => {
-            sort_primitive::<TimestampNanosecondType, _>(
-                values, v, n, cmp, &options, limit,
-            )
-        }
-        DataType::Interval(IntervalUnit::YearMonth) => {
-            sort_primitive::<IntervalYearMonthType, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::Interval(IntervalUnit::DayTime) => {
-            sort_primitive::<IntervalDayTimeType, _>(values, v, n, cmp, &options, limit)
-        }
-        DataType::Interval(IntervalUnit::MonthDayNano) => {
-            sort_primitive::<IntervalMonthDayNanoType, _>(
-                values, v, n, cmp, &options, limit,
-            )
-        }
-        DataType::Duration(TimeUnit::Second) => {
-            sort_primitive::<DurationSecondType, _>(values, v, n, cmp, &options, limit)
+    let (v, n) = partition_validity(array);
+
+    Ok(downcast_primitive_array! {
+        array => sort_primitive(array, v, n, options, limit),
+        DataType::Boolean => sort_boolean(array.as_boolean(), v, n, options, limit),
+        DataType::Utf8 => sort_bytes(array.as_string::<i32>(), v, n, options, limit),
+        DataType::LargeUtf8 => sort_bytes(array.as_string::<i64>(), v, n, options, limit),
+        DataType::Binary => sort_bytes(array.as_binary::<i32>(), v, n, options, limit),
+        DataType::LargeBinary => sort_bytes(array.as_binary::<i64>(), v, n, options, limit),
+        DataType::FixedSizeBinary(_) => sort_fixed_size_binary(array.as_fixed_size_binary(), v, n, options, limit),
+        DataType::List(_) => sort_list(array.as_list::<i32>(), v, n, options, limit)?,
+        DataType::LargeList(_) => sort_list(array.as_list::<i64>(), v, n, options, limit)?,
+        DataType::FixedSizeList(_, _) => sort_fixed_size_list(array.as_fixed_size_list(), v, n, options, limit)?,
+        DataType::Dictionary(_, _) => downcast_dictionary_array!{
+            array => sort_dictionary(array, v, n, options, limit)?,
+            _ => unreachable!()
         }
-        DataType::Duration(TimeUnit::Millisecond) => {
-            sort_primitive::<DurationMillisecondType, _>(
-                values, v, n, cmp, &options, limit,
-            )
-        }
-        DataType::Duration(TimeUnit::Microsecond) => {
-            sort_primitive::<DurationMicrosecondType, _>(
-                values, v, n, cmp, &options, limit,
-            )
-        }
-        DataType::Duration(TimeUnit::Nanosecond) => {
-            sort_primitive::<DurationNanosecondType, _>(
-                values, v, n, cmp, &options, limit,
-            )
-        }
-        DataType::Utf8 => sort_string::<i32>(values, v, n, &options, limit),
-        DataType::LargeUtf8 => sort_string::<i64>(values, v, n, &options, limit),
-        DataType::List(field) | DataType::FixedSizeList(field, _) => {
-            match field.data_type() {
-                DataType::Int8 => sort_list::<i32>(values, v, n, &options, limit),
-                DataType::Int16 => sort_list::<i32>(values, v, n, &options, limit),
-                DataType::Int32 => sort_list::<i32>(values, v, n, &options, limit),
-                DataType::Int64 => sort_list::<i32>(values, v, n, &options, limit),
-                DataType::UInt8 => sort_list::<i32>(values, v, n, &options, limit),
-                DataType::UInt16 => sort_list::<i32>(values, v, n, &options, limit),
-                DataType::UInt32 => sort_list::<i32>(values, v, n, &options, limit),
-                DataType::UInt64 => sort_list::<i32>(values, v, n, &options, limit),
-                DataType::Float16 => sort_list::<i32>(values, v, n, &options, limit),
-                DataType::Float32 => sort_list::<i32>(values, v, n, &options, limit),
-                DataType::Float64 => sort_list::<i32>(values, v, n, &options, limit),
-                t => {
-                    return Err(ArrowError::ComputeError(format!(
-                        "Sort not supported for list type {t:?}"
-                    )));
-                }
-            }
-        }
-        DataType::LargeList(field) => match field.data_type() {
-            DataType::Int8 => sort_list::<i64>(values, v, n, &options, limit),
-            DataType::Int16 => sort_list::<i64>(values, v, n, &options, limit),
-            DataType::Int32 => sort_list::<i64>(values, v, n, &options, limit),
-            DataType::Int64 => sort_list::<i64>(values, v, n, &options, limit),
-            DataType::UInt8 => sort_list::<i64>(values, v, n, &options, limit),
-            DataType::UInt16 => sort_list::<i64>(values, v, n, &options, limit),
-            DataType::UInt32 => sort_list::<i64>(values, v, n, &options, limit),
-            DataType::UInt64 => sort_list::<i64>(values, v, n, &options, limit),
-            DataType::Float16 => sort_list::<i64>(values, v, n, &options, limit),
-            DataType::Float32 => sort_list::<i64>(values, v, n, &options, limit),
-            DataType::Float64 => sort_list::<i64>(values, v, n, &options, limit),
-            t => {
-                return Err(ArrowError::ComputeError(format!(
-                    "Sort not supported for list type {t:?}"
-                )));
-            }
-        },
-        DataType::Dictionary(_, _) => {
-            let value_null_first = if options.descending {
-                // When sorting dictionary in descending order, we take inverse of of null ordering
-                // when sorting the values. Because if `nulls_first` is true, null must be in front
-                // of non-null value. As we take the sorted order of value array to sort dictionary
-                // keys, these null values will be treated as smallest ones and be sorted to the end
-                // of sorted result. So we set `nulls_first` to false when sorting dictionary value
-                // array to make them as largest ones, then null values will be put at the beginning
-                // of sorted dictionary result.
-                !options.nulls_first
-            } else {
-                options.nulls_first
-            };
-            let value_options = Some(SortOptions {
-                descending: false,
-                nulls_first: value_null_first,
-            });
-            downcast_dictionary_array! {
-                values => {
-                    let dict_values = values.values();
-                    let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?;
-                    let rank = sorted_rank(&sorted_value_indices);
-                    sort_dictionary(values, &rank, v, n, options, limit)
-                }
-                _ => unreachable!(),
-            }
-        }
-        DataType::Binary | DataType::FixedSizeBinary(_) => {
-            sort_binary::<i32>(values, v, n, &options, limit)
-        }
-        DataType::LargeBinary => sort_binary::<i64>(values, v, n, &options, limit),
         DataType::RunEndEncoded(run_ends_field, _) => match run_ends_field.data_type() {
-            DataType::Int16 => sort_run_to_indices::<Int16Type>(values, &options, limit),
-            DataType::Int32 => sort_run_to_indices::<Int32Type>(values, &options, limit),
-            DataType::Int64 => sort_run_to_indices::<Int64Type>(values, &options, limit),
+            DataType::Int16 => sort_run_to_indices::<Int16Type>(array, options, limit),
+            DataType::Int32 => sort_run_to_indices::<Int32Type>(array, options, limit),
+            DataType::Int64 => sort_run_to_indices::<Int64Type>(array, options, limit),
             dt => {
                 return Err(ArrowError::ComputeError(format!(
                     "Invalid run end data type: {dt}"
@@ -422,147 +237,76 @@ pub fn sort_to_indices(
     })
 }
 
-/// Sort boolean values
-///
-/// when a limit is present, the sort is pair-comparison based as k-select might be more efficient,
-/// when the limit is absent, binary partition is used to speed up (which is linear).
-///
-/// TODO maybe partition_validity call can be eliminated in this case
-/// and [tri-color sort](https://en.wikipedia.org/wiki/Dutch_national_flag_problem)
-/// can be used instead.
 fn sort_boolean(
-    values: &dyn Array,
+    values: &BooleanArray,
     value_indices: Vec<u32>,
-    mut null_indices: Vec<u32>,
-    options: &SortOptions,
+    null_indices: Vec<u32>,
+    options: SortOptions,
     limit: Option<usize>,
 ) -> UInt32Array {
-    let values = values
-        .as_any()
-        .downcast_ref::<BooleanArray>()
-        .expect("Unable to downcast to boolean array");
-    let descending = options.descending;
-
-    let valids_len = value_indices.len();
-    let nulls_len = null_indices.len();
-
-    let mut len = values.len();
-    let valids = if let Some(limit) = limit {
-        len = limit.min(len);
-        // create tuples that are used for sorting
-        let mut valids = value_indices
-            .into_iter()
-            .map(|index| (index, values.value(index as usize)))
-            .collect::<Vec<(u32, bool)>>();
-
-        sort_valids(descending, &mut valids, len, cmp);
-        valids
-    } else {
-        // when limit is not present, we have a better way than sorting: we can just partition
-        // the vec into [false..., true...] or [true..., false...] when descending
-        // TODO when https://github.com/rust-lang/rust/issues/62543 is merged we can use partition_in_place
-        let (mut a, b): (Vec<_>, Vec<_>) = value_indices
-            .into_iter()
-            .map(|index| (index, values.value(index as usize)))
-            .partition(|(_, value)| *value == descending);
-        a.extend(b);
-        if descending {
-            null_indices.reverse();
-        }
-        a
-    };
-
-    let nulls = null_indices;
-
-    // collect results directly into a buffer instead of a vec to avoid another aligned allocation
-    let result_capacity = len * std::mem::size_of::<u32>();
-    let mut result = MutableBuffer::new(result_capacity);
-    // sets len to capacity so we can access the whole buffer as a typed slice
-    result.resize(result_capacity, 0);
-    let result_slice: &mut [u32] = result.typed_data_mut();
-
-    if options.nulls_first {
-        let size = nulls_len.min(len);
-        result_slice[0..size].copy_from_slice(&nulls[0..size]);
-        if nulls_len < len {
-            insert_valid_values(result_slice, nulls_len, &valids[0..len - size]);
-        }
-    } else {
-        // nulls last
-        let size = valids.len().min(len);
-        insert_valid_values(result_slice, 0, &valids[0..size]);
-        if len > size {
-            result_slice[valids_len..].copy_from_slice(&nulls[0..(len - valids_len)]);
-        }
-    }
-
-    let result_data = unsafe {
-        ArrayData::new_unchecked(
-            DataType::UInt32,
-            len,
-            Some(0),
-            None,
-            0,
-            vec![result.into()],
-            vec![],
-        )
-    };
+    let mut valids = value_indices
+        .into_iter()
+        .map(|index| (index, values.value(index as usize)))
+        .collect::<Vec<(u32, bool)>>();
+    sort_impl(options, &mut valids, &null_indices, limit, |a, b| a.cmp(&b)).into()
+}
 
-    UInt32Array::from(result_data)
+fn sort_primitive<T: ArrowPrimitiveType>(
+    values: &PrimitiveArray<T>,
+    value_indices: Vec<u32>,
+    nulls: Vec<u32>,
+    options: SortOptions,
+    limit: Option<usize>,
+) -> UInt32Array {
+    let mut valids = value_indices
+        .into_iter()
+        .map(|index| (index, values.value(index as usize)))
+        .collect::<Vec<(u32, T::Native)>>();
+    sort_impl(options, &mut valids, &nulls, limit, T::Native::compare).into()
 }
 
-/// Sort primitive values
-fn sort_primitive<T, F>(
-    values: &dyn Array,
+fn sort_bytes<T: ByteArrayType>(
+    values: &GenericByteArray<T>,
     value_indices: Vec<u32>,
-    null_indices: Vec<u32>,
-    cmp: F,
-    options: &SortOptions,
+    nulls: Vec<u32>,
+    options: SortOptions,
     limit: Option<usize>,
-) -> UInt32Array
-where
-    T: ArrowPrimitiveType,
-    T::Native: PartialOrd,
-    F: Fn(T::Native, T::Native) -> Ordering,
-{
-    // create tuples that are used for sorting
-    let valids = {
-        let values = values.as_primitive::<T>();
-        value_indices
-            .into_iter()
-            .map(|index| (index, values.value(index as usize)))
-            .collect::<Vec<(u32, T::Native)>>()
-    };
-    sort_primitive_inner(values.len(), null_indices, cmp, options, limit, valids)
+) -> UInt32Array {
+    let mut valids = value_indices
+        .into_iter()
+        .map(|index| (index, values.value(index as usize).as_ref()))
+        .collect::<Vec<(u32, &[u8])>>();
+
+    sort_impl(options, &mut valids, &nulls, limit, Ord::cmp).into()
 }
 
-/// Given a list of indices that yield a sorted order, returns the ordered
-/// rank of each index
-///
-/// e.g. [2, 4, 3, 1, 0] -> [4, 3, 0, 2, 1]
-fn sorted_rank(sorted_value_indices: &UInt32Array) -> Vec<u32> {
-    assert_eq!(sorted_value_indices.null_count(), 0);
-    let sorted_indices = sorted_value_indices.values();
-    let mut out: Vec<_> = vec![0_u32; sorted_indices.len()];
-    for (ix, val) in sorted_indices.iter().enumerate() {
-        out[*val as usize] = ix as u32;
-    }
-    out
+fn sort_fixed_size_binary(
+    values: &FixedSizeBinaryArray,
+    value_indices: Vec<u32>,
+    nulls: Vec<u32>,
+    options: SortOptions,
+    limit: Option<usize>,
+) -> UInt32Array {
+    let mut valids = value_indices
+        .iter()
+        .copied()
+        .map(|index| (index, values.value(index as usize)))
+        .collect::<Vec<(u32, &[u8])>>();
+    sort_impl(options, &mut valids, &nulls, limit, Ord::cmp).into()
 }
 
-/// Sort dictionary given the sorted rank of each key
 fn sort_dictionary<K: ArrowDictionaryKeyType>(
     dict: &DictionaryArray<K>,
-    rank: &[u32],
     value_indices: Vec<u32>,
     null_indices: Vec<u32>,
     options: SortOptions,
     limit: Option<usize>,
-) -> UInt32Array {
+) -> Result<UInt32Array, ArrowError> {
     let keys: &PrimitiveArray<K> = dict.keys();
+    let rank = child_rank(dict.values().as_ref(), options)?;
 
     // create tuples that are used for sorting
-    let valids = value_indices
+    let mut valids = value_indices
         .into_iter()
         .map(|index| {
             let key: K::Native = keys.value(index as usize);
@@ -570,83 +314,100 @@ fn sort_dictionary<K: ArrowDictionaryKeyType>(
         })
         .collect::<Vec<(u32, u32)>>();
 
-    sort_primitive_inner::<_, _>(keys.len(), null_indices, cmp, &options, limit, valids)
+    Ok(sort_impl(options, &mut valids, &null_indices, limit, |a, b| a.cmp(&b)).into())
 }
 
-// sort is instantiated a lot so we only compile this inner version for each native type
-fn sort_primitive_inner<T, F>(
-    value_len: usize,
-    nulls: Vec<u32>,
-    cmp: F,
-    options: &SortOptions,
+fn sort_list<O: OffsetSizeTrait>(
+    array: &GenericListArray<O>,
+    value_indices: Vec<u32>,
+    null_indices: Vec<u32>,
+    options: SortOptions,
     limit: Option<usize>,
-    mut valids: Vec<(u32, T)>,
-) -> UInt32Array
-where
-    T: ArrowNativeType,
-    T: PartialOrd,
-    F: Fn(T, T) -> Ordering,
-{
-    let valids_len = valids.len();
-    let nulls_len = nulls.len();
-    let mut len = value_len;
+) -> Result<UInt32Array, ArrowError> {
+    let rank = child_rank(array.values().as_ref(), options)?;
+    let offsets = array.value_offsets();
+    let mut valids = value_indices
+        .into_iter()
+        .map(|index| {
+            let end = offsets[index as usize + 1].as_usize();
+            let start = offsets[index as usize].as_usize();
+            (index, &rank[start..end])
+        })
+        .collect::<Vec<(u32, &[u32])>>();
+    Ok(sort_impl(options, &mut valids, &null_indices, limit, Ord::cmp).into())
+}
 
-    if let Some(limit) = limit {
-        len = limit.min(len);
-    }
+fn sort_fixed_size_list(
+    array: &FixedSizeListArray,
+    value_indices: Vec<u32>,
+    null_indices: Vec<u32>,
+    options: SortOptions,
+    limit: Option<usize>,
+) -> Result<UInt32Array, ArrowError> {
+    let rank = child_rank(array.values().as_ref(), options)?;
+    let size = array.value_length() as usize;
+    let mut valids = value_indices
+        .into_iter()
+        .map(|index| {
+            let start = index as usize * size;
+            (index, &rank[start..start + size])
+        })
+        .collect::<Vec<(u32, &[u32])>>();
+    Ok(sort_impl(options, &mut valids, &null_indices, limit, Ord::cmp).into())
+}
 
-    sort_valids(options.descending, &mut valids, len, cmp);
+#[inline(never)]
+fn sort_impl<T: ?Sized + Copy>(
+    options: SortOptions,
+    valids: &mut [(u32, T)],
+    nulls: &[u32],
+    limit: Option<usize>,
+    mut cmp: impl FnMut(T, T) -> Ordering,
+) -> Vec<u32> {
+    let v_limit = match (limit, options.nulls_first) {
+        (Some(l), true) => l.saturating_sub(nulls.len()).min(valids.len()),
+        _ => valids.len(),
+    };
 
-    // collect results directly into a buffer instead of a vec to avoid another aligned allocation
-    let result_capacity = len * std::mem::size_of::<u32>();
-    let mut result = MutableBuffer::new(result_capacity);
-    // sets len to capacity so we can access the whole buffer as a typed slice
-    result.resize(result_capacity, 0);
-    let result_slice: &mut [u32] = result.typed_data_mut();
+    match options.descending {
+        false => sort_unstable_by(valids, v_limit, |a, b| cmp(a.1, b.1)),
+        true => sort_unstable_by(valids, v_limit, |a, b| cmp(a.1, b.1).reverse()),
+    }
 
-    if options.nulls_first {
-        let size = nulls_len.min(len);
-        result_slice[0..size].copy_from_slice(&nulls[0..size]);
-        if nulls_len < len {
-            insert_valid_values(result_slice, nulls_len, &valids[0..len - size]);
+    let len = valids.len() + nulls.len();
+    let limit = limit.unwrap_or(len).min(len);
+    let mut out = Vec::with_capacity(len);
+    match options.nulls_first {
+        true => {
+            out.extend_from_slice(&nulls[..nulls.len().min(limit)]);
+            let remaining = limit - out.len();
+            out.extend(valids.iter().map(|x| x.0).take(remaining));
         }
-    } else {
-        // nulls last
-        let size = valids.len().min(len);
-        insert_valid_values(result_slice, 0, &valids[0..size]);
-        if len > size {
-            result_slice[valids_len..].copy_from_slice(&nulls[0..(len - valids_len)]);
+        false => {
+            out.extend(valids.iter().map(|x| x.0).take(limit));
+            let remaining = limit - out.len();
+            out.extend_from_slice(&nulls[..remaining])
         }
     }
-
-    let result_data = unsafe {
-        ArrayData::new_unchecked(
-            DataType::UInt32,
-            len,
-            Some(0),
-            None,
-            0,
-            vec![result.into()],
-            vec![],
-        )
-    };
-
-    UInt32Array::from(result_data)
+    out
 }
 
-// insert valid and nan values in the correct order depending on the descending flag
-fn insert_valid_values<T>(result_slice: &mut [u32], offset: usize, valids: &[(u32, T)]) {
-    let valids_len = valids.len();
-    // helper to append the index part of the valid tuples
-    let append_valids = move |dst_slice: &mut [u32]| {
-        debug_assert_eq!(dst_slice.len(), valids_len);
-        dst_slice
-            .iter_mut()
-            .zip(valids.iter())
-            .for_each(|(dst, src)| *dst = src.0)
-    };
+/// Computes the rank for a set of child values
+fn child_rank(values: &dyn Array, options: SortOptions) -> Result<Vec<u32>, ArrowError> {
+    // If parent sort order is descending we need to invert the value of nulls_first so that
+    // when the parent is sorted based on the produced ranks, nulls are still ordered correctly
+    let value_options = Some(SortOptions {
+        descending: false,
+        nulls_first: options.nulls_first != options.descending,
+    });
 
-    append_valids(&mut result_slice[offset..offset + valids.len()]);
+    let sorted_value_indices = sort_to_indices(values, value_options, None)?;
+    let sorted_indices = sorted_value_indices.values();
+    let mut out: Vec<_> = vec![0_u32; sorted_indices.len()];
+    for (ix, val) in sorted_indices.iter().enumerate() {
+        out[*val as usize] = ix as u32;
+    }
+    Ok(out)
 }
 
 // Sort run array and return sorted run array.
@@ -737,7 +498,7 @@ fn sort_run_downcasted<R: RunEndIndexType>(
 // encoded back to run array.
 fn sort_run_to_indices<R: RunEndIndexType>(
     values: &dyn Array,
-    options: &SortOptions,
+    options: SortOptions,
     limit: Option<usize>,
 ) -> UInt32Array {
     let run_array = values.as_any().downcast_ref::<RunArray<R>>().unwrap();
@@ -752,7 +513,7 @@ fn sort_run_to_indices<R: RunEndIndexType>(
     let consume_runs = |run_length, logical_start| {
         result.extend(logical_start as u32..(logical_start + run_length) as u32);
     };
-    sort_run_inner(run_array, Some(*options), output_len, consume_runs);
+    sort_run_inner(run_array, Some(options), output_len, consume_runs);
 
     UInt32Array::from(result)
 }
@@ -834,200 +595,6 @@ where
     (values_indices, run_values)
 }
 
-/// Sort strings
-fn sort_string<Offset: OffsetSizeTrait>(
-    values: &dyn Array,
-    value_indices: Vec<u32>,
-    null_indices: Vec<u32>,
-    options: &SortOptions,
-    limit: Option<usize>,
-) -> UInt32Array {
-    let values = values
-        .as_any()
-        .downcast_ref::<GenericStringArray<Offset>>()
-        .unwrap();
-
-    sort_string_helper(
-        values,
-        value_indices,
-        null_indices,
-        options,
-        limit,
-        |array, idx| array.value(idx as usize),
-    )
-}
-
-/// shared implementation between dictionary encoded and plain string arrays
-#[inline]
-fn sort_string_helper<'a, A: Array, F>(
-    values: &'a A,
-    value_indices: Vec<u32>,
-    null_indices: Vec<u32>,
-    options: &SortOptions,
-    limit: Option<usize>,
-    value_fn: F,
-) -> UInt32Array
-where
-    F: Fn(&'a A, u32) -> &str,
-{
-    let mut valids = value_indices
-        .into_iter()
-        .map(|index| (index, value_fn(values, index)))
-        .collect::<Vec<(u32, &str)>>();
-    let mut nulls = null_indices;
-    let descending = options.descending;
-    let mut len = values.len();
-
-    if let Some(limit) = limit {
-        len = limit.min(len);
-    }
-
-    sort_valids(descending, &mut valids, len, cmp);
-    // collect the order of valid tuplies
-    let mut valid_indices: Vec<u32> = valids.iter().map(|tuple| tuple.0).collect();
-
-    if options.nulls_first {
-        nulls.append(&mut valid_indices);
-        nulls.truncate(len);
-        UInt32Array::from(nulls)
-    } else {
-        // no need to sort nulls as they are in the correct order already
-        valid_indices.append(&mut nulls);
-        valid_indices.truncate(len);
-        UInt32Array::from(valid_indices)
-    }
-}
-
-fn sort_list<S>(
-    values: &dyn Array,
-    value_indices: Vec<u32>,
-    null_indices: Vec<u32>,
-    options: &SortOptions,
-    limit: Option<usize>,
-) -> UInt32Array
-where
-    S: OffsetSizeTrait,
-{
-    sort_list_inner::<S>(values, value_indices, null_indices, options, limit)
-}
-
-fn sort_list_inner<S>(
-    values: &dyn Array,
-    value_indices: Vec<u32>,
-    mut null_indices: Vec<u32>,
-    options: &SortOptions,
-    limit: Option<usize>,
-) -> UInt32Array
-where
-    S: OffsetSizeTrait,
-{
-    let mut valids: Vec<(u32, ArrayRef)> = values
-        .as_any()
-        .downcast_ref::<FixedSizeListArray>()
-        .map_or_else(
-            || {
-                let values = as_generic_list_array::<S>(values);
-                value_indices
-                    .iter()
-                    .copied()
-                    .map(|index| (index, values.value(index as usize)))
-                    .collect()
-            },
-            |values| {
-                value_indices
-                    .iter()
-                    .copied()
-                    .map(|index| (index, values.value(index as usize)))
-                    .collect()
-            },
-        );
-
-    let mut len = values.len();
-    let descending = options.descending;
-
-    if let Some(limit) = limit {
-        len = limit.min(len);
-    }
-    sort_valids_array(descending, &mut valids, &mut null_indices, len);
-
-    let mut valid_indices: Vec<u32> = valids.iter().map(|tuple| tuple.0).collect();
-    if options.nulls_first {
-        null_indices.append(&mut valid_indices);
-        null_indices.truncate(len);
-        UInt32Array::from(null_indices)
-    } else {
-        valid_indices.append(&mut null_indices);
-        valid_indices.truncate(len);
-        UInt32Array::from(valid_indices)
-    }
-}
-
-fn sort_binary<S>(
-    values: &dyn Array,
-    value_indices: Vec<u32>,
-    mut null_indices: Vec<u32>,
-    options: &SortOptions,
-    limit: Option<usize>,
-) -> UInt32Array
-where
-    S: OffsetSizeTrait,
-{
-    let mut valids: Vec<(u32, &[u8])> = values
-        .as_any()
-        .downcast_ref::<FixedSizeBinaryArray>()
-        .map_or_else(
-            || {
-                let values = as_generic_binary_array::<S>(values);
-                value_indices
-                    .iter()
-                    .copied()
-                    .map(|index| (index, values.value(index as usize)))
-                    .collect()
-            },
-            |values| {
-                value_indices
-                    .iter()
-                    .copied()
-                    .map(|index| (index, values.value(index as usize)))
-                    .collect()
-            },
-        );
-
-    let mut len = values.len();
-    let descending = options.descending;
-
-    if let Some(limit) = limit {
-        len = limit.min(len);
-    }
-
-    sort_valids(descending, &mut valids, len, cmp);
-
-    let mut valid_indices: Vec<u32> = valids.iter().map(|tuple| tuple.0).collect();
-    if options.nulls_first {
-        null_indices.append(&mut valid_indices);
-        null_indices.truncate(len);
-        UInt32Array::from(null_indices)
-    } else {
-        valid_indices.append(&mut null_indices);
-        valid_indices.truncate(len);
-        UInt32Array::from(valid_indices)
-    }
-}
-
-/// Compare two `Array`s based on the ordering defined in [build_compare]
-fn cmp_array(a: &dyn Array, b: &dyn Array) -> Ordering {
-    let cmp_op = build_compare(a, b).unwrap();
-    let length = a.len().max(b.len());
-
-    for i in 0..length {
-        let result = cmp_op(i, i);
-        if result != Ordering::Equal {
-            return result;
-        }
-    }
-    Ordering::Equal
-}
-
 /// One column to be used in lexicographical sort
 #[derive(Clone, Debug)]
 pub struct SortColumn {
@@ -1146,8 +713,10 @@ pub fn partial_sort<T, F>(v: &mut [T], limit: usize, mut is_less: F)
 where
     F: FnMut(&T, &T) -> Ordering,
 {
-    let (before, _mid, _after) = v.select_nth_unstable_by(limit, &mut is_less);
-    before.sort_unstable_by(is_less);
+    if let Some(n) = limit.checked_sub(1) {
+        let (before, _mid, _after) = v.select_nth_unstable_by(n, &mut is_less);
+        before.sort_unstable_by(is_less);
+    }
 }
 
 type LexicographicalCompareItem<'a> = (
@@ -1228,42 +797,6 @@ impl LexicographicalComparator<'_> {
     }
 }
 
-fn sort_valids<T>(
-    descending: bool,
-    valids: &mut [(u32, T)],
-    len: usize,
-    mut cmp: impl FnMut(T, T) -> Ordering,
-) where
-    T: ?Sized + Copy,
-{
-    let valids_len = valids.len();
-    if !descending {
-        sort_unstable_by(valids, len.min(valids_len), |a, b| cmp(a.1, b.1));
-    } else {
-        sort_unstable_by(valids, len.min(valids_len), |a, b| cmp(a.1, b.1).reverse());
-    }
-}
-
-fn sort_valids_array<T>(
-    descending: bool,
-    valids: &mut [(u32, ArrayRef)],
-    nulls: &mut [T],
-    len: usize,
-) {
-    let valids_len = valids.len();
-    if !descending {
-        sort_unstable_by(valids, len.min(valids_len), |a, b| {
-            cmp_array(a.1.as_ref(), b.1.as_ref())
-        });
-    } else {
-        sort_unstable_by(valids, len.min(valids_len), |a, b| {
-            cmp_array(a.1.as_ref(), b.1.as_ref()).reverse()
-        });
-        // reverse to keep a stable ordering
-        nulls.reverse();
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -1980,7 +1513,7 @@ mod tests {
                 nulls_first: false,
             }),
             None,
-            vec![2, 3, 1, 4, 5, 0],
+            vec![2, 3, 1, 4, 0, 5],
         );
 
         // boolean, descending, nulls first
@@ -1991,7 +1524,7 @@ mod tests {
                 nulls_first: true,
             }),
             None,
-            vec![5, 0, 2, 3, 1, 4],
+            vec![0, 5, 2, 3, 1, 4],
         );
 
         // boolean, descending, nulls first, limit
diff --git a/arrow/benches/sort_kernel.rs b/arrow/benches/sort_kernel.rs
index 3a3ce4462d..63e10e0528 100644
--- a/arrow/benches/sort_kernel.rs
+++ b/arrow/benches/sort_kernel.rs
@@ -67,23 +67,37 @@ fn bench_sort_to_indices(array: &dyn Array, limit: Option<usize>) {
 
 fn add_benchmark(c: &mut Criterion) {
     let arr = create_primitive_array::<Int32Type>(2usize.pow(10), 0.0);
-    c.bench_function("sort i64 2^10", |b| b.iter(|| bench_sort(&arr)));
-
-    let arr = create_primitive_array::<Int32Type>(2usize.pow(12), 0.5);
-    c.bench_function("sort i64 2^12", |b| b.iter(|| bench_sort(&arr)));
+    c.bench_function("sort i32 2^10", |b| b.iter(|| bench_sort(&arr)));
+    c.bench_function("sort i32 to indices 2^10", |b| {
+        b.iter(|| bench_sort_to_indices(&arr, None))
+    });
 
     let arr = create_primitive_array::<Int32Type>(2usize.pow(12), 0.0);
-    c.bench_function("sort i64 nulls 2^10", |b| b.iter(|| bench_sort(&arr)));
+    c.bench_function("sort i32 2^12", |b| b.iter(|| bench_sort(&arr)));
+    c.bench_function("sort i32 to indices 2^12", |b| {
+        b.iter(|| bench_sort_to_indices(&arr, None))
+    });
+
+    let arr = create_primitive_array::<Int32Type>(2usize.pow(10), 0.5);
+    c.bench_function("sort i32 nulls 2^10", |b| b.iter(|| bench_sort(&arr)));
+    c.bench_function("sort i32 nulls to indices 2^10", |b| {
+        b.iter(|| bench_sort_to_indices(&arr, None))
+    });
 
     let arr = create_primitive_array::<Int32Type>(2usize.pow(12), 0.5);
-    c.bench_function("sort i64 nulls 2^12", |b| b.iter(|| bench_sort(&arr)));
+    c.bench_function("sort i32 nulls 2^12", |b| b.iter(|| bench_sort(&arr)));
+    c.bench_function("sort i32 nulls to indices 2^12", |b| {
+        b.iter(|| bench_sort_to_indices(&arr, None))
+    });
 
     let arr = create_f32_array(2_usize.pow(12), false);
+    c.bench_function("sort f32 2^12", |b| b.iter(|| bench_sort(&arr)));
     c.bench_function("sort f32 to indices 2^12", |b| {
         b.iter(|| bench_sort_to_indices(&arr, None))
     });
 
     let arr = create_f32_array(2usize.pow(12), true);
+    c.bench_function("sort f32 nulls 2^12", |b| b.iter(|| bench_sort(&arr)));
     c.bench_function("sort f32 nulls to indices 2^12", |b| {
         b.iter(|| bench_sort_to_indices(&arr, None))
     });


Reply via email to