Dandandan commented on code in PR #4473:
URL: https://github.com/apache/arrow-rs/pull/4473#discussion_r1251263331
##########
arrow-ord/src/sort.rs:
##########
@@ -57,11 +58,137 @@ pub fn sort(
values: &dyn Array,
options: Option<SortOptions>,
) -> Result<ArrayRef, ArrowError> {
- if let DataType::RunEndEncoded(_, _) = values.data_type() {
- return sort_run(values, options, None);
+ match values.data_type() {
+ DataType::Int8 => sort_native_type::<Int8Type>(values, options),
+ DataType::Int16 => sort_native_type::<Int16Type>(values, options),
+ DataType::Int32 => sort_native_type::<Int32Type>(values, options),
+ DataType::Int64 => sort_native_type::<Int64Type>(values, options),
+ DataType::UInt8 => sort_native_type::<UInt8Type>(values, options),
+ DataType::UInt16 => sort_native_type::<UInt16Type>(values, options),
+ DataType::UInt32 => sort_native_type::<UInt32Type>(values, options),
+ DataType::UInt64 => sort_native_type::<UInt64Type>(values, options),
+ DataType::Float32 => sort_native_type::<Float32Type>(values, options),
+ DataType::Float64 => sort_native_type::<Float64Type>(values, options),
+ DataType::Date32 => sort_native_type::<Date32Type>(values, options),
+ DataType::Date64 => sort_native_type::<Date64Type>(values, options),
+ DataType::Time32(TimeUnit::Second) => {
+ sort_native_type::<Time32SecondType>(values, options)
+ }
+ DataType::Time32(TimeUnit::Millisecond) => {
+ sort_native_type::<Time32MillisecondType>(values, options)
+ }
+ DataType::Time64(TimeUnit::Microsecond) => {
+ sort_native_type::<Time64MicrosecondType>(values, options)
+ }
+ DataType::Time64(TimeUnit::Nanosecond) => {
+ sort_native_type::<Time64NanosecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Second, _) => {
+ sort_native_type::<TimestampSecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Millisecond, _) => {
+ sort_native_type::<TimestampMillisecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, _) => {
+ sort_native_type::<TimestampMicrosecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+ sort_native_type::<TimestampNanosecondType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::YearMonth) => {
+ sort_native_type::<IntervalYearMonthType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::DayTime) => {
+ sort_native_type::<IntervalDayTimeType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::MonthDayNano) => {
+ sort_native_type::<IntervalMonthDayNanoType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Second) => {
+ sort_native_type::<DurationSecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Millisecond) => {
+ sort_native_type::<DurationMillisecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Microsecond) => {
+ sort_native_type::<DurationMicrosecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Nanosecond) => {
+ sort_native_type::<DurationNanosecondType>(values, options)
+ }
+ DataType::RunEndEncoded(_, _) => sort_run(values, options, None),
+ _ => {
+ let indices = sort_to_indices(values, options, None)?;
+ take(values, &indices, None)
+ }
}
- let indices = sort_to_indices(values, options, None)?;
- take(values, &indices, None)
+}
+
+fn sort_native_type<T>(
+ values: &dyn Array,
+ options: Option<SortOptions>,
+) -> Result<ArrayRef, ArrowError>
+where
+ T: ArrowPrimitiveType,
+ <T as arrow_array::ArrowPrimitiveType>::Native: ArrowNativeTypeOp,
+{
+ let sort_options = options.unwrap_or_default();
+ let primitive_values = values.as_primitive::<T>();
+
+ let result_capacity = values.len()
+ * std::mem::size_of::<<T as arrow_array::ArrowPrimitiveType>::Native>();
+ let mut mutable_buffer = MutableBuffer::new(result_capacity);
+ mutable_buffer.resize(result_capacity, 0);
+
+ let mutable_slice =
+ mutable_buffer.typed_data_mut::<<T as arrow_array::ArrowPrimitiveType>::Native>();
+
+ let array_data = values.to_data();
+ let input_values = array_data.buffer(0);
+
+ let mut null_bit_buffer = None;
+
+ let nulls_count = values.null_count();
+ let valid_count = values.len() - nulls_count;
+
+ if values.null_count() > 0 {
+ let nulls = array_data.nulls().unwrap();
+
+ let mut validity_buffer = BooleanBufferBuilder::new(values.len());
+ let values_slice;
+
+ if sort_options.nulls_first {
+ values_slice = &mut mutable_slice[nulls_count..];
+ validity_buffer.append_n(nulls_count, false);
+ validity_buffer.append_n(valid_count, true);
+ } else {
+ values_slice = &mut mutable_slice[..valid_count];
+ validity_buffer.append_n(valid_count, true);
+ validity_buffer.append_n(nulls_count, false);
+ }
+
+ for (write_index, index) in nulls.valid_indices().enumerate() {
+ values_slice[write_index] = primitive_values.value(index);
+ }
+
+ values_slice.sort_unstable_by(|a, b| a.compare(*b));
+ if sort_options.descending {
+ values_slice.reverse();
+ }
+
+ null_bit_buffer = Some(validity_buffer.finish().into());
+ } else {
+ mutable_slice.copy_from_slice(&input_values[..values.len()]);
+ mutable_slice.sort_unstable_by(|a, b| a.compare(*b));
Review Comment:
```suggestion
mutable_slice.sort_unstable();
```
Shouldn't this behave the same as `sort_unstable_by(|a, b| a.compare(*b))`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]