tustvold commented on code in PR #4473:
URL: https://github.com/apache/arrow-rs/pull/4473#discussion_r1251253583
##########
arrow-ord/src/sort.rs:
##########
@@ -57,11 +58,137 @@ pub fn sort(
values: &dyn Array,
options: Option<SortOptions>,
) -> Result<ArrayRef, ArrowError> {
- if let DataType::RunEndEncoded(_, _) = values.data_type() {
- return sort_run(values, options, None);
+ match values.data_type() {
+ DataType::Int8 => sort_native_type::<Int8Type>(values, options),
+ DataType::Int16 => sort_native_type::<Int16Type>(values, options),
+ DataType::Int32 => sort_native_type::<Int32Type>(values, options),
+ DataType::Int64 => sort_native_type::<Int64Type>(values, options),
+ DataType::UInt8 => sort_native_type::<UInt8Type>(values, options),
+ DataType::UInt16 => sort_native_type::<UInt16Type>(values, options),
+ DataType::UInt32 => sort_native_type::<UInt32Type>(values, options),
+ DataType::UInt64 => sort_native_type::<UInt64Type>(values, options),
+ DataType::Float32 => sort_native_type::<Float32Type>(values, options),
+ DataType::Float64 => sort_native_type::<Float64Type>(values, options),
+ DataType::Date32 => sort_native_type::<Date32Type>(values, options),
+ DataType::Date64 => sort_native_type::<Date64Type>(values, options),
+ DataType::Time32(TimeUnit::Second) => {
+ sort_native_type::<Time32SecondType>(values, options)
+ }
+ DataType::Time32(TimeUnit::Millisecond) => {
+ sort_native_type::<Time32MillisecondType>(values, options)
+ }
+ DataType::Time64(TimeUnit::Microsecond) => {
+ sort_native_type::<Time64MicrosecondType>(values, options)
+ }
+ DataType::Time64(TimeUnit::Nanosecond) => {
+ sort_native_type::<Time64NanosecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Second, _) => {
+ sort_native_type::<TimestampSecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Millisecond, _) => {
+ sort_native_type::<TimestampMillisecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, _) => {
+ sort_native_type::<TimestampMicrosecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+ sort_native_type::<TimestampNanosecondType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::YearMonth) => {
+ sort_native_type::<IntervalYearMonthType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::DayTime) => {
+ sort_native_type::<IntervalDayTimeType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::MonthDayNano) => {
+ sort_native_type::<IntervalMonthDayNanoType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Second) => {
+ sort_native_type::<DurationSecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Millisecond) => {
+ sort_native_type::<DurationMillisecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Microsecond) => {
+ sort_native_type::<DurationMicrosecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Nanosecond) => {
+ sort_native_type::<DurationNanosecondType>(values, options)
+ }
+ DataType::RunEndEncoded(_, _) => sort_run(values, options, None),
+ _ => {
+ let indices = sort_to_indices(values, options, None)?;
+ take(values, &indices, None)
+ }
}
- let indices = sort_to_indices(values, options, None)?;
- take(values, &indices, None)
+}
+
+fn sort_native_type<T>(
+ values: &dyn Array,
+ options: Option<SortOptions>,
+) -> Result<ArrayRef, ArrowError>
+where
+ T: ArrowPrimitiveType,
+ <T as arrow_array::ArrowPrimitiveType>::Native: ArrowNativeTypeOp,
Review Comment:
This shouldn't be necessary, given the existing constraints on
`ArrowPrimitiveType::Native`.
##########
arrow-ord/src/sort.rs:
##########
@@ -57,11 +58,137 @@ pub fn sort(
values: &dyn Array,
options: Option<SortOptions>,
) -> Result<ArrayRef, ArrowError> {
- if let DataType::RunEndEncoded(_, _) = values.data_type() {
- return sort_run(values, options, None);
+ match values.data_type() {
+ DataType::Int8 => sort_native_type::<Int8Type>(values, options),
+ DataType::Int16 => sort_native_type::<Int16Type>(values, options),
+ DataType::Int32 => sort_native_type::<Int32Type>(values, options),
+ DataType::Int64 => sort_native_type::<Int64Type>(values, options),
+ DataType::UInt8 => sort_native_type::<UInt8Type>(values, options),
+ DataType::UInt16 => sort_native_type::<UInt16Type>(values, options),
+ DataType::UInt32 => sort_native_type::<UInt32Type>(values, options),
+ DataType::UInt64 => sort_native_type::<UInt64Type>(values, options),
+ DataType::Float32 => sort_native_type::<Float32Type>(values, options),
+ DataType::Float64 => sort_native_type::<Float64Type>(values, options),
+ DataType::Date32 => sort_native_type::<Date32Type>(values, options),
+ DataType::Date64 => sort_native_type::<Date64Type>(values, options),
+ DataType::Time32(TimeUnit::Second) => {
+ sort_native_type::<Time32SecondType>(values, options)
+ }
+ DataType::Time32(TimeUnit::Millisecond) => {
+ sort_native_type::<Time32MillisecondType>(values, options)
+ }
+ DataType::Time64(TimeUnit::Microsecond) => {
+ sort_native_type::<Time64MicrosecondType>(values, options)
+ }
+ DataType::Time64(TimeUnit::Nanosecond) => {
+ sort_native_type::<Time64NanosecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Second, _) => {
+ sort_native_type::<TimestampSecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Millisecond, _) => {
+ sort_native_type::<TimestampMillisecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, _) => {
+ sort_native_type::<TimestampMicrosecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+ sort_native_type::<TimestampNanosecondType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::YearMonth) => {
+ sort_native_type::<IntervalYearMonthType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::DayTime) => {
+ sort_native_type::<IntervalDayTimeType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::MonthDayNano) => {
+ sort_native_type::<IntervalMonthDayNanoType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Second) => {
+ sort_native_type::<DurationSecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Millisecond) => {
+ sort_native_type::<DurationMillisecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Microsecond) => {
+ sort_native_type::<DurationMicrosecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Nanosecond) => {
+ sort_native_type::<DurationNanosecondType>(values, options)
+ }
+ DataType::RunEndEncoded(_, _) => sort_run(values, options, None),
+ _ => {
+ let indices = sort_to_indices(values, options, None)?;
+ take(values, &indices, None)
+ }
}
- let indices = sort_to_indices(values, options, None)?;
- take(values, &indices, None)
+}
+
+fn sort_native_type<T>(
+ values: &dyn Array,
+ options: Option<SortOptions>,
+) -> Result<ArrayRef, ArrowError>
+where
+ T: ArrowPrimitiveType,
+ <T as arrow_array::ArrowPrimitiveType>::Native: ArrowNativeTypeOp,
+{
+ let sort_options = options.unwrap_or_default();
+ let primitive_values = values.as_primitive::<T>();
+
+ let result_capacity = values.len()
+ * std::mem::size_of::<<T as
arrow_array::ArrowPrimitiveType>::Native>();
+ let mut mutable_buffer = MutableBuffer::new(result_capacity);
Review Comment:
Have you considered just using a `Vec` here?
##########
arrow-ord/src/sort.rs:
##########
@@ -57,11 +58,137 @@ pub fn sort(
values: &dyn Array,
options: Option<SortOptions>,
) -> Result<ArrayRef, ArrowError> {
- if let DataType::RunEndEncoded(_, _) = values.data_type() {
- return sort_run(values, options, None);
+ match values.data_type() {
+ DataType::Int8 => sort_native_type::<Int8Type>(values, options),
+ DataType::Int16 => sort_native_type::<Int16Type>(values, options),
+ DataType::Int32 => sort_native_type::<Int32Type>(values, options),
+ DataType::Int64 => sort_native_type::<Int64Type>(values, options),
+ DataType::UInt8 => sort_native_type::<UInt8Type>(values, options),
+ DataType::UInt16 => sort_native_type::<UInt16Type>(values, options),
+ DataType::UInt32 => sort_native_type::<UInt32Type>(values, options),
+ DataType::UInt64 => sort_native_type::<UInt64Type>(values, options),
+ DataType::Float32 => sort_native_type::<Float32Type>(values, options),
+ DataType::Float64 => sort_native_type::<Float64Type>(values, options),
+ DataType::Date32 => sort_native_type::<Date32Type>(values, options),
+ DataType::Date64 => sort_native_type::<Date64Type>(values, options),
+ DataType::Time32(TimeUnit::Second) => {
+ sort_native_type::<Time32SecondType>(values, options)
+ }
+ DataType::Time32(TimeUnit::Millisecond) => {
+ sort_native_type::<Time32MillisecondType>(values, options)
+ }
+ DataType::Time64(TimeUnit::Microsecond) => {
+ sort_native_type::<Time64MicrosecondType>(values, options)
+ }
+ DataType::Time64(TimeUnit::Nanosecond) => {
+ sort_native_type::<Time64NanosecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Second, _) => {
+ sort_native_type::<TimestampSecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Millisecond, _) => {
+ sort_native_type::<TimestampMillisecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, _) => {
+ sort_native_type::<TimestampMicrosecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+ sort_native_type::<TimestampNanosecondType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::YearMonth) => {
+ sort_native_type::<IntervalYearMonthType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::DayTime) => {
+ sort_native_type::<IntervalDayTimeType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::MonthDayNano) => {
+ sort_native_type::<IntervalMonthDayNanoType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Second) => {
+ sort_native_type::<DurationSecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Millisecond) => {
+ sort_native_type::<DurationMillisecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Microsecond) => {
+ sort_native_type::<DurationMicrosecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Nanosecond) => {
+ sort_native_type::<DurationNanosecondType>(values, options)
+ }
+ DataType::RunEndEncoded(_, _) => sort_run(values, options, None),
+ _ => {
+ let indices = sort_to_indices(values, options, None)?;
+ take(values, &indices, None)
+ }
}
- let indices = sort_to_indices(values, options, None)?;
- take(values, &indices, None)
+}
+
+fn sort_native_type<T>(
+ values: &dyn Array,
+ options: Option<SortOptions>,
+) -> Result<ArrayRef, ArrowError>
+where
+ T: ArrowPrimitiveType,
+ <T as arrow_array::ArrowPrimitiveType>::Native: ArrowNativeTypeOp,
+{
+ let sort_options = options.unwrap_or_default();
+ let primitive_values = values.as_primitive::<T>();
+
+ let result_capacity = values.len()
+ * std::mem::size_of::<<T as
arrow_array::ArrowPrimitiveType>::Native>();
+ let mut mutable_buffer = MutableBuffer::new(result_capacity);
+ mutable_buffer.resize(result_capacity, 0);
+
+ let mutable_slice =
+ mutable_buffer.typed_data_mut::<<T as
arrow_array::ArrowPrimitiveType>::Native>();
+
+ let array_data = values.to_data();
+ let input_values = array_data.buffer(0);
+
+ let mut null_bit_buffer = None;
+
+ let nulls_count = values.null_count();
+ let valid_count = values.len() - nulls_count;
+
+ if values.null_count() > 0 {
+ let nulls = array_data.nulls().unwrap();
Review Comment:
```suggestion
if let Some(nulls) = array.nulls().filter(|n| n.null_count() > 0) {
```
##########
arrow-ord/src/sort.rs:
##########
@@ -57,11 +58,137 @@ pub fn sort(
values: &dyn Array,
options: Option<SortOptions>,
) -> Result<ArrayRef, ArrowError> {
- if let DataType::RunEndEncoded(_, _) = values.data_type() {
- return sort_run(values, options, None);
+ match values.data_type() {
+ DataType::Int8 => sort_native_type::<Int8Type>(values, options),
+ DataType::Int16 => sort_native_type::<Int16Type>(values, options),
+ DataType::Int32 => sort_native_type::<Int32Type>(values, options),
+ DataType::Int64 => sort_native_type::<Int64Type>(values, options),
+ DataType::UInt8 => sort_native_type::<UInt8Type>(values, options),
+ DataType::UInt16 => sort_native_type::<UInt16Type>(values, options),
+ DataType::UInt32 => sort_native_type::<UInt32Type>(values, options),
+ DataType::UInt64 => sort_native_type::<UInt64Type>(values, options),
+ DataType::Float32 => sort_native_type::<Float32Type>(values, options),
+ DataType::Float64 => sort_native_type::<Float64Type>(values, options),
+ DataType::Date32 => sort_native_type::<Date32Type>(values, options),
+ DataType::Date64 => sort_native_type::<Date64Type>(values, options),
+ DataType::Time32(TimeUnit::Second) => {
+ sort_native_type::<Time32SecondType>(values, options)
+ }
+ DataType::Time32(TimeUnit::Millisecond) => {
+ sort_native_type::<Time32MillisecondType>(values, options)
+ }
+ DataType::Time64(TimeUnit::Microsecond) => {
+ sort_native_type::<Time64MicrosecondType>(values, options)
+ }
+ DataType::Time64(TimeUnit::Nanosecond) => {
+ sort_native_type::<Time64NanosecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Second, _) => {
+ sort_native_type::<TimestampSecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Millisecond, _) => {
+ sort_native_type::<TimestampMillisecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, _) => {
+ sort_native_type::<TimestampMicrosecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+ sort_native_type::<TimestampNanosecondType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::YearMonth) => {
+ sort_native_type::<IntervalYearMonthType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::DayTime) => {
+ sort_native_type::<IntervalDayTimeType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::MonthDayNano) => {
+ sort_native_type::<IntervalMonthDayNanoType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Second) => {
+ sort_native_type::<DurationSecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Millisecond) => {
+ sort_native_type::<DurationMillisecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Microsecond) => {
+ sort_native_type::<DurationMicrosecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Nanosecond) => {
+ sort_native_type::<DurationNanosecondType>(values, options)
+ }
+ DataType::RunEndEncoded(_, _) => sort_run(values, options, None),
+ _ => {
+ let indices = sort_to_indices(values, options, None)?;
+ take(values, &indices, None)
+ }
}
- let indices = sort_to_indices(values, options, None)?;
- take(values, &indices, None)
+}
+
+fn sort_native_type<T>(
+ values: &dyn Array,
+ options: Option<SortOptions>,
+) -> Result<ArrayRef, ArrowError>
+where
+ T: ArrowPrimitiveType,
+ <T as arrow_array::ArrowPrimitiveType>::Native: ArrowNativeTypeOp,
+{
+ let sort_options = options.unwrap_or_default();
+ let primitive_values = values.as_primitive::<T>();
+
+ let result_capacity = values.len()
+ * std::mem::size_of::<<T as
arrow_array::ArrowPrimitiveType>::Native>();
+ let mut mutable_buffer = MutableBuffer::new(result_capacity);
+ mutable_buffer.resize(result_capacity, 0);
+
+ let mutable_slice =
+ mutable_buffer.typed_data_mut::<<T as
arrow_array::ArrowPrimitiveType>::Native>();
+
+ let array_data = values.to_data();
+ let input_values = array_data.buffer(0);
+
+ let mut null_bit_buffer = None;
Review Comment:
It might be nicer to use an expression style here rather than a `mut`
binding, e.g.:
let nulls = match array.nulls().filter(|n| n.null_count() > 0) {
Some(nulls) => ...,
None => ...
}
##########
arrow-ord/src/sort.rs:
##########
@@ -57,11 +58,137 @@ pub fn sort(
values: &dyn Array,
options: Option<SortOptions>,
) -> Result<ArrayRef, ArrowError> {
- if let DataType::RunEndEncoded(_, _) = values.data_type() {
- return sort_run(values, options, None);
+ match values.data_type() {
+ DataType::Int8 => sort_native_type::<Int8Type>(values, options),
+ DataType::Int16 => sort_native_type::<Int16Type>(values, options),
+ DataType::Int32 => sort_native_type::<Int32Type>(values, options),
+ DataType::Int64 => sort_native_type::<Int64Type>(values, options),
+ DataType::UInt8 => sort_native_type::<UInt8Type>(values, options),
+ DataType::UInt16 => sort_native_type::<UInt16Type>(values, options),
+ DataType::UInt32 => sort_native_type::<UInt32Type>(values, options),
+ DataType::UInt64 => sort_native_type::<UInt64Type>(values, options),
+ DataType::Float32 => sort_native_type::<Float32Type>(values, options),
+ DataType::Float64 => sort_native_type::<Float64Type>(values, options),
+ DataType::Date32 => sort_native_type::<Date32Type>(values, options),
+ DataType::Date64 => sort_native_type::<Date64Type>(values, options),
+ DataType::Time32(TimeUnit::Second) => {
+ sort_native_type::<Time32SecondType>(values, options)
+ }
+ DataType::Time32(TimeUnit::Millisecond) => {
+ sort_native_type::<Time32MillisecondType>(values, options)
+ }
+ DataType::Time64(TimeUnit::Microsecond) => {
+ sort_native_type::<Time64MicrosecondType>(values, options)
+ }
+ DataType::Time64(TimeUnit::Nanosecond) => {
+ sort_native_type::<Time64NanosecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Second, _) => {
+ sort_native_type::<TimestampSecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Millisecond, _) => {
+ sort_native_type::<TimestampMillisecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, _) => {
+ sort_native_type::<TimestampMicrosecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+ sort_native_type::<TimestampNanosecondType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::YearMonth) => {
+ sort_native_type::<IntervalYearMonthType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::DayTime) => {
+ sort_native_type::<IntervalDayTimeType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::MonthDayNano) => {
+ sort_native_type::<IntervalMonthDayNanoType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Second) => {
+ sort_native_type::<DurationSecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Millisecond) => {
+ sort_native_type::<DurationMillisecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Microsecond) => {
+ sort_native_type::<DurationMicrosecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Nanosecond) => {
+ sort_native_type::<DurationNanosecondType>(values, options)
+ }
+ DataType::RunEndEncoded(_, _) => sort_run(values, options, None),
+ _ => {
+ let indices = sort_to_indices(values, options, None)?;
+ take(values, &indices, None)
+ }
}
- let indices = sort_to_indices(values, options, None)?;
- take(values, &indices, None)
+}
+
+fn sort_native_type<T>(
+ values: &dyn Array,
+ options: Option<SortOptions>,
+) -> Result<ArrayRef, ArrowError>
+where
+ T: ArrowPrimitiveType,
+ <T as arrow_array::ArrowPrimitiveType>::Native: ArrowNativeTypeOp,
+{
+ let sort_options = options.unwrap_or_default();
+ let primitive_values = values.as_primitive::<T>();
+
+ let result_capacity = values.len()
+ * std::mem::size_of::<<T as
arrow_array::ArrowPrimitiveType>::Native>();
+ let mut mutable_buffer = MutableBuffer::new(result_capacity);
+ mutable_buffer.resize(result_capacity, 0);
+
+ let mutable_slice =
+ mutable_buffer.typed_data_mut::<<T as
arrow_array::ArrowPrimitiveType>::Native>();
+
+ let array_data = values.to_data();
+ let input_values = array_data.buffer(0);
Review Comment:
```suggestion
let array = array.as_primitive::<T>();
let input_values = array.values().as_ref();
```
This not only avoids marshaling to `ArrayData`; the current code is also
technically relying on an implementation detail, namely that `PrimitiveArray`
returns `ArrayData` with a zero offset.
##########
arrow-ord/src/sort.rs:
##########
@@ -57,11 +58,137 @@ pub fn sort(
values: &dyn Array,
options: Option<SortOptions>,
) -> Result<ArrayRef, ArrowError> {
- if let DataType::RunEndEncoded(_, _) = values.data_type() {
- return sort_run(values, options, None);
+ match values.data_type() {
+ DataType::Int8 => sort_native_type::<Int8Type>(values, options),
+ DataType::Int16 => sort_native_type::<Int16Type>(values, options),
+ DataType::Int32 => sort_native_type::<Int32Type>(values, options),
+ DataType::Int64 => sort_native_type::<Int64Type>(values, options),
+ DataType::UInt8 => sort_native_type::<UInt8Type>(values, options),
+ DataType::UInt16 => sort_native_type::<UInt16Type>(values, options),
+ DataType::UInt32 => sort_native_type::<UInt32Type>(values, options),
+ DataType::UInt64 => sort_native_type::<UInt64Type>(values, options),
+ DataType::Float32 => sort_native_type::<Float32Type>(values, options),
+ DataType::Float64 => sort_native_type::<Float64Type>(values, options),
+ DataType::Date32 => sort_native_type::<Date32Type>(values, options),
+ DataType::Date64 => sort_native_type::<Date64Type>(values, options),
+ DataType::Time32(TimeUnit::Second) => {
+ sort_native_type::<Time32SecondType>(values, options)
+ }
+ DataType::Time32(TimeUnit::Millisecond) => {
+ sort_native_type::<Time32MillisecondType>(values, options)
+ }
+ DataType::Time64(TimeUnit::Microsecond) => {
+ sort_native_type::<Time64MicrosecondType>(values, options)
+ }
+ DataType::Time64(TimeUnit::Nanosecond) => {
+ sort_native_type::<Time64NanosecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Second, _) => {
+ sort_native_type::<TimestampSecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Millisecond, _) => {
+ sort_native_type::<TimestampMillisecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, _) => {
+ sort_native_type::<TimestampMicrosecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+ sort_native_type::<TimestampNanosecondType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::YearMonth) => {
+ sort_native_type::<IntervalYearMonthType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::DayTime) => {
+ sort_native_type::<IntervalDayTimeType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::MonthDayNano) => {
+ sort_native_type::<IntervalMonthDayNanoType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Second) => {
+ sort_native_type::<DurationSecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Millisecond) => {
+ sort_native_type::<DurationMillisecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Microsecond) => {
+ sort_native_type::<DurationMicrosecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Nanosecond) => {
+ sort_native_type::<DurationNanosecondType>(values, options)
+ }
+ DataType::RunEndEncoded(_, _) => sort_run(values, options, None),
+ _ => {
+ let indices = sort_to_indices(values, options, None)?;
+ take(values, &indices, None)
+ }
}
- let indices = sort_to_indices(values, options, None)?;
- take(values, &indices, None)
+}
+
+fn sort_native_type<T>(
+ values: &dyn Array,
+ options: Option<SortOptions>,
+) -> Result<ArrayRef, ArrowError>
+where
+ T: ArrowPrimitiveType,
+ <T as arrow_array::ArrowPrimitiveType>::Native: ArrowNativeTypeOp,
+{
+ let sort_options = options.unwrap_or_default();
+ let primitive_values = values.as_primitive::<T>();
+
+ let result_capacity = values.len()
+ * std::mem::size_of::<<T as
arrow_array::ArrowPrimitiveType>::Native>();
+ let mut mutable_buffer = MutableBuffer::new(result_capacity);
+ mutable_buffer.resize(result_capacity, 0);
+
+ let mutable_slice =
+ mutable_buffer.typed_data_mut::<<T as
arrow_array::ArrowPrimitiveType>::Native>();
+
+ let array_data = values.to_data();
+ let input_values = array_data.buffer(0);
+
+ let mut null_bit_buffer = None;
+
+ let nulls_count = values.null_count();
+ let valid_count = values.len() - nulls_count;
+
+ if values.null_count() > 0 {
+ let nulls = array_data.nulls().unwrap();
+
+ let mut validity_buffer = BooleanBufferBuilder::new(values.len());
+ let values_slice;
Review Comment:
I personally prefer the expression style, e.g.
```
let values_slice = match sort_options.nulls_first {
true => ...,
false => ...
}
```
It makes it easier to see what is going on and where the value is being set.
##########
arrow-ord/src/sort.rs:
##########
@@ -57,11 +58,137 @@ pub fn sort(
values: &dyn Array,
options: Option<SortOptions>,
) -> Result<ArrayRef, ArrowError> {
- if let DataType::RunEndEncoded(_, _) = values.data_type() {
- return sort_run(values, options, None);
+ match values.data_type() {
+ DataType::Int8 => sort_native_type::<Int8Type>(values, options),
+ DataType::Int16 => sort_native_type::<Int16Type>(values, options),
+ DataType::Int32 => sort_native_type::<Int32Type>(values, options),
+ DataType::Int64 => sort_native_type::<Int64Type>(values, options),
+ DataType::UInt8 => sort_native_type::<UInt8Type>(values, options),
+ DataType::UInt16 => sort_native_type::<UInt16Type>(values, options),
+ DataType::UInt32 => sort_native_type::<UInt32Type>(values, options),
+ DataType::UInt64 => sort_native_type::<UInt64Type>(values, options),
+ DataType::Float32 => sort_native_type::<Float32Type>(values, options),
+ DataType::Float64 => sort_native_type::<Float64Type>(values, options),
+ DataType::Date32 => sort_native_type::<Date32Type>(values, options),
+ DataType::Date64 => sort_native_type::<Date64Type>(values, options),
+ DataType::Time32(TimeUnit::Second) => {
+ sort_native_type::<Time32SecondType>(values, options)
+ }
+ DataType::Time32(TimeUnit::Millisecond) => {
+ sort_native_type::<Time32MillisecondType>(values, options)
+ }
+ DataType::Time64(TimeUnit::Microsecond) => {
+ sort_native_type::<Time64MicrosecondType>(values, options)
+ }
+ DataType::Time64(TimeUnit::Nanosecond) => {
+ sort_native_type::<Time64NanosecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Second, _) => {
+ sort_native_type::<TimestampSecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Millisecond, _) => {
+ sort_native_type::<TimestampMillisecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, _) => {
+ sort_native_type::<TimestampMicrosecondType>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+ sort_native_type::<TimestampNanosecondType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::YearMonth) => {
+ sort_native_type::<IntervalYearMonthType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::DayTime) => {
+ sort_native_type::<IntervalDayTimeType>(values, options)
+ }
+ DataType::Interval(IntervalUnit::MonthDayNano) => {
+ sort_native_type::<IntervalMonthDayNanoType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Second) => {
+ sort_native_type::<DurationSecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Millisecond) => {
+ sort_native_type::<DurationMillisecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Microsecond) => {
+ sort_native_type::<DurationMicrosecondType>(values, options)
+ }
+ DataType::Duration(TimeUnit::Nanosecond) => {
+ sort_native_type::<DurationNanosecondType>(values, options)
+ }
+ DataType::RunEndEncoded(_, _) => sort_run(values, options, None),
+ _ => {
+ let indices = sort_to_indices(values, options, None)?;
+ take(values, &indices, None)
+ }
}
- let indices = sort_to_indices(values, options, None)?;
- take(values, &indices, None)
+}
+
+fn sort_native_type<T>(
+ values: &dyn Array,
+ options: Option<SortOptions>,
+) -> Result<ArrayRef, ArrowError>
+where
+ T: ArrowPrimitiveType,
+ <T as arrow_array::ArrowPrimitiveType>::Native: ArrowNativeTypeOp,
+{
+ let sort_options = options.unwrap_or_default();
+ let primitive_values = values.as_primitive::<T>();
+
+ let result_capacity = values.len()
+ * std::mem::size_of::<<T as
arrow_array::ArrowPrimitiveType>::Native>();
+ let mut mutable_buffer = MutableBuffer::new(result_capacity);
+ mutable_buffer.resize(result_capacity, 0);
+
+ let mutable_slice =
+ mutable_buffer.typed_data_mut::<<T as
arrow_array::ArrowPrimitiveType>::Native>();
+
+ let array_data = values.to_data();
+ let input_values = array_data.buffer(0);
+
+ let mut null_bit_buffer = None;
+
+ let nulls_count = values.null_count();
+ let valid_count = values.len() - nulls_count;
+
+ if values.null_count() > 0 {
+ let nulls = array_data.nulls().unwrap();
+
+ let mut validity_buffer = BooleanBufferBuilder::new(values.len());
+ let values_slice;
+
+ if sort_options.nulls_first {
+ values_slice = &mut mutable_slice[nulls_count..];
+ validity_buffer.append_n(nulls_count, false);
+ validity_buffer.append_n(valid_count, true);
+ } else {
+ values_slice = &mut mutable_slice[..valid_count];
+ validity_buffer.append_n(valid_count, true);
+ validity_buffer.append_n(nulls_count, false);
+ }
+
+ for (write_index, index) in nulls.valid_indices().enumerate() {
+ values_slice[write_index] = primitive_values.value(index);
+ }
+
+ values_slice.sort_unstable_by(|a, b| a.compare(*b));
+ if sort_options.descending {
+ values_slice.reverse();
+ }
+
+ null_bit_buffer = Some(validity_buffer.finish().into());
+ } else {
+ mutable_slice.copy_from_slice(&input_values[..values.len()]);
+ mutable_slice.sort_unstable_by(|a, b| a.compare(*b));
+ if sort_options.descending {
+ mutable_slice.reverse();
+ }
Review Comment:
Given the marginal speed difference, I think it makes sense to save on
codegen by using `reverse` :+1:
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]