tustvold commented on code in PR #4613:
URL: https://github.com/apache/arrow-rs/pull/4613#discussion_r1280464808
##########
arrow-ord/src/sort.rs:
##########
@@ -204,210 +196,33 @@ fn partition_validity(array: &dyn Array) -> (Vec<u32>,
Vec<u32>) {
/// For floating point arrays any NaN values are considered to be greater than
any other non-null value.
/// `limit` is an option for [partial_sort].
pub fn sort_to_indices(
- values: &dyn Array,
+ array: &dyn Array,
options: Option<SortOptions>,
limit: Option<usize>,
) -> Result<UInt32Array, ArrowError> {
let options = options.unwrap_or_default();
- let (v, n) = partition_validity(values);
-
- Ok(match values.data_type() {
- DataType::Decimal128(_, _) => {
- sort_primitive::<Decimal128Type, _>(values, v, n, cmp, &options,
limit)
- }
- DataType::Decimal256(_, _) => {
- sort_primitive::<Decimal256Type, _>(values, v, n, cmp, &options,
limit)
- }
- DataType::Boolean => sort_boolean(values, v, n, &options, limit),
- DataType::Int8 => {
- sort_primitive::<Int8Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Int16 => {
- sort_primitive::<Int16Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Int32 => {
- sort_primitive::<Int32Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Int64 => {
- sort_primitive::<Int64Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::UInt8 => {
- sort_primitive::<UInt8Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::UInt16 => {
- sort_primitive::<UInt16Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::UInt32 => {
- sort_primitive::<UInt32Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::UInt64 => {
- sort_primitive::<UInt64Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Float16 => sort_primitive::<Float16Type, _>(
- values,
- v,
- n,
- |x, y| x.total_cmp(&y),
- &options,
- limit,
- ),
- DataType::Float32 => sort_primitive::<Float32Type, _>(
- values,
- v,
- n,
- |x, y| x.total_cmp(&y),
- &options,
- limit,
- ),
- DataType::Float64 => sort_primitive::<Float64Type, _>(
- values,
- v,
- n,
- |x, y| x.total_cmp(&y),
- &options,
- limit,
- ),
- DataType::Date32 => {
- sort_primitive::<Date32Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Date64 => {
- sort_primitive::<Date64Type, _>(values, v, n, cmp, &options, limit)
- }
- DataType::Time32(TimeUnit::Second) => {
- sort_primitive::<Time32SecondType, _>(values, v, n, cmp, &options,
limit)
- }
- DataType::Time32(TimeUnit::Millisecond) => {
- sort_primitive::<Time32MillisecondType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Time64(TimeUnit::Microsecond) => {
- sort_primitive::<Time64MicrosecondType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Time64(TimeUnit::Nanosecond) => {
- sort_primitive::<Time64NanosecondType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Timestamp(TimeUnit::Second, _) => {
- sort_primitive::<TimestampSecondType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Timestamp(TimeUnit::Millisecond, _) => {
- sort_primitive::<TimestampMillisecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Timestamp(TimeUnit::Microsecond, _) => {
- sort_primitive::<TimestampMicrosecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Timestamp(TimeUnit::Nanosecond, _) => {
- sort_primitive::<TimestampNanosecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Interval(IntervalUnit::YearMonth) => {
- sort_primitive::<IntervalYearMonthType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Interval(IntervalUnit::DayTime) => {
- sort_primitive::<IntervalDayTimeType, _>(values, v, n, cmp,
&options, limit)
- }
- DataType::Interval(IntervalUnit::MonthDayNano) => {
- sort_primitive::<IntervalMonthDayNanoType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Duration(TimeUnit::Second) => {
- sort_primitive::<DurationSecondType, _>(values, v, n, cmp,
&options, limit)
+ let (v, n) = partition_validity(array);
+
+ Ok(downcast_primitive_array! {
+ array => sort_primitive(array, v, n, options, limit),
+ DataType::Boolean => sort_boolean(array.as_boolean(), v, n, options,
limit),
+ DataType::Utf8 => sort_bytes(array.as_string::<i32>(), v, n, options,
limit),
+ DataType::LargeUtf8 => sort_bytes(array.as_string::<i64>(), v, n,
options, limit),
+ DataType::Binary => sort_bytes(array.as_binary::<i32>(), v, n,
options, limit),
+ DataType::LargeBinary => sort_bytes(array.as_binary::<i64>(), v, n,
options, limit),
+ DataType::FixedSizeBinary(_) =>
sort_fixed_size_binary(array.as_fixed_size_binary(), v, n, options, limit),
+ DataType::List(_) => sort_list(array.as_list::<i32>(), v, n, options,
limit)?,
+ DataType::LargeList(_) => sort_list(array.as_list::<i64>(), v, n,
options, limit)?,
+ DataType::FixedSizeList(_, _) =>
sort_fixed_size_list(array.as_fixed_size_list(), v, n, options, limit)?,
+ DataType::Dictionary(_, _) => downcast_dictionary_array!{
+ array => sort_dictionary(array, v, n, options, limit)?,
+ _ => unreachable!()
}
- DataType::Duration(TimeUnit::Millisecond) => {
- sort_primitive::<DurationMillisecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Duration(TimeUnit::Microsecond) => {
- sort_primitive::<DurationMicrosecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Duration(TimeUnit::Nanosecond) => {
- sort_primitive::<DurationNanosecondType, _>(
- values, v, n, cmp, &options, limit,
- )
- }
- DataType::Utf8 => sort_string::<i32>(values, v, n, &options, limit),
- DataType::LargeUtf8 => sort_string::<i64>(values, v, n, &options,
limit),
- DataType::List(field) | DataType::FixedSizeList(field, _) => {
- match field.data_type() {
- DataType::Int8 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::Int16 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::Int32 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::Int64 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::UInt8 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::UInt16 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::UInt32 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::UInt64 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::Float16 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::Float32 => sort_list::<i32>(values, v, n, &options,
limit),
- DataType::Float64 => sort_list::<i32>(values, v, n, &options,
limit),
- t => {
- return Err(ArrowError::ComputeError(format!(
- "Sort not supported for list type {t:?}"
- )));
- }
- }
- }
- DataType::LargeList(field) => match field.data_type() {
- DataType::Int8 => sort_list::<i64>(values, v, n, &options, limit),
- DataType::Int16 => sort_list::<i64>(values, v, n, &options, limit),
- DataType::Int32 => sort_list::<i64>(values, v, n, &options, limit),
- DataType::Int64 => sort_list::<i64>(values, v, n, &options, limit),
- DataType::UInt8 => sort_list::<i64>(values, v, n, &options, limit),
- DataType::UInt16 => sort_list::<i64>(values, v, n, &options,
limit),
- DataType::UInt32 => sort_list::<i64>(values, v, n, &options,
limit),
- DataType::UInt64 => sort_list::<i64>(values, v, n, &options,
limit),
- DataType::Float16 => sort_list::<i64>(values, v, n, &options,
limit),
- DataType::Float32 => sort_list::<i64>(values, v, n, &options,
limit),
- DataType::Float64 => sort_list::<i64>(values, v, n, &options,
limit),
- t => {
- return Err(ArrowError::ComputeError(format!(
- "Sort not supported for list type {t:?}"
- )));
- }
- },
- DataType::Dictionary(_, _) => {
- let value_null_first = if options.descending {
Review Comment:
This logic has been moved into a `child_rank` function.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]