psvri commented on code in PR #4473:
URL: https://github.com/apache/arrow-rs/pull/4473#discussion_r1251110673
##########
arrow-ord/src/sort.rs:
##########
@@ -57,11 +60,211 @@ pub fn sort(
values: &dyn Array,
options: Option<SortOptions>,
) -> Result<ArrayRef, ArrowError> {
- if let DataType::RunEndEncoded(_, _) = values.data_type() {
- return sort_run(values, options, None);
// Dispatch on the logical type. Fixed-width primitive types are sorted by
// value in `sort_native_type`. Run-end encoded arrays go through `sort_run`.
// Every other type falls back to computing sort indices and gathering the
// values with `take`.
// NOTE(review): the second type argument of `sort_native_type` is not tied
// to `T::Native` by its signature, so every arm pairs them by hand (e.g.
// `IntervalDayTimeType` with `i64`). A mismatch would reinterpret the value
// buffer as the wrong element type. Consider using `T::Native` directly.
+ match values.data_type() {
+ DataType::Int8 => sort_native_type::<Int8Type, i8>(values, options),
+ DataType::Int16 => sort_native_type::<Int16Type, i16>(values, options),
+ DataType::Int32 => sort_native_type::<Int32Type, i32>(values, options),
+ DataType::Int64 => sort_native_type::<Int64Type, i64>(values, options),
+ DataType::UInt8 => sort_native_type::<UInt8Type, u8>(values, options),
+ DataType::UInt16 => sort_native_type::<UInt16Type, u16>(values,
options),
+ DataType::UInt32 => sort_native_type::<UInt32Type, u32>(values,
options),
+ DataType::UInt64 => sort_native_type::<UInt64Type, u64>(values,
options),
+ DataType::Float32 => sort_native_type::<Float32Type, f32>(values,
options),
+ DataType::Float64 => sort_native_type::<Float64Type, f64>(values,
options),
+ DataType::Date32 => sort_native_type::<Date32Type, i32>(values,
options),
+ DataType::Date64 => sort_native_type::<Date64Type, i64>(values,
options),
+ DataType::Time32(TimeUnit::Second) => {
+ sort_native_type::<Time32SecondType, i32>(values, options)
+ }
+ DataType::Time32(TimeUnit::Millisecond) => {
+ sort_native_type::<Time32MillisecondType, i32>(values, options)
+ }
+ DataType::Time64(TimeUnit::Microsecond) => {
+ sort_native_type::<Time64MicrosecondType, i64>(values, options)
+ }
+ DataType::Time64(TimeUnit::Nanosecond) => {
+ sort_native_type::<Time64NanosecondType, i64>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Second, _) => {
+ sort_native_type::<TimestampSecondType, i64>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Millisecond, _) => {
+ sort_native_type::<TimestampMillisecondType, i64>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, _) => {
+ sort_native_type::<TimestampMicrosecondType, i64>(values, options)
+ }
+ DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+ sort_native_type::<TimestampNanosecondType, i64>(values, options)
+ }
+ DataType::Interval(IntervalUnit::YearMonth) => {
+ sort_native_type::<IntervalYearMonthType, i32>(values, options)
+ }
+ DataType::Interval(IntervalUnit::DayTime) => {
+ sort_native_type::<IntervalDayTimeType, i64>(values, options)
+ }
+ DataType::Interval(IntervalUnit::MonthDayNano) => {
+ sort_native_type::<IntervalMonthDayNanoType, i128>(values, options)
+ }
+ DataType::Duration(TimeUnit::Second) => {
+ sort_native_type::<DurationSecondType, i64>(values, options)
+ }
+ DataType::Duration(TimeUnit::Millisecond) => {
+ sort_native_type::<DurationMillisecondType, i64>(values, options)
+ }
+ DataType::Duration(TimeUnit::Microsecond) => {
+ sort_native_type::<DurationMicrosecondType, i64>(values, options)
+ }
+ DataType::Duration(TimeUnit::Nanosecond) => {
+ sort_native_type::<DurationNanosecondType, i64>(values, options)
+ }
+ DataType::RunEndEncoded(_, _) => sort_run(values, options, None),
// Any type not matched above: compute the sorted order as indices first,
// then gather the values in that order.
+ _ => {
+ let indices = sort_to_indices(values, options, None)?;
+ take(values, &indices, None)
+ }
}
- let indices = sort_to_indices(values, options, None)?;
- take(values, &indices, None)
+}
+
/// Packs the elements of an 8-element chunk whose bits are set in `mask`.
///
/// For every bit `i` of `mask` that is set (least significant bit first),
/// copies `*input.offset(i)` to the next slot starting at `output`. Returns
/// the number of elements written, i.e. the number of set bits.
///
/// NOTE(review): this is a safe `fn` that dereferences raw pointers with no
/// checks. Callers must guarantee two things:
/// - `input.offset(i)` is readable for every set bit `i`. Bit `i` must
///   describe element `i`, so the bitmap cannot carry a bit offset
///   relative to `input`.
/// - `output` has room for all the set bits.
/// Consider making this an `unsafe fn` with a `# Safety` section.
+fn compress_store<U>(input: *const U, mut output: *mut U, mask: u8) -> isize
+where
+ U: ArrowNativeType,
+{
+ let mut offset = 0;
// An all-zero mask (eight nulls) writes nothing, so skip the bit loop.
+ if mask != 0 {
+ for i in 0..8 {
+ if (mask & (1 << i)) != 0 {
+ // SAFETY: relies on a set bit only ever marking an in-bounds slot, i.e. a
valid value
+ unsafe {
+ *output = *input.offset(i);
+ offset += 1;
+ output = output.offset(1);
+ }
+ }
+ }
+ }
+ offset
+}
+
/// Builds the validity bitmap for a sorted array of `length` slots.
///
/// The `valid_count` non-null values are contiguous. They start at slot
/// `nulls_count` when `sort_options.nulls_first` is set, and at slot 0
/// otherwise. Bits are LSB-first within each byte, and every other bit is 0.
///
/// Always returns `Some`.
///
/// NOTE(review): assumes `valid_count + nulls_count <= length`. This is not
/// checked here; larger counts could index past the end of the bitmap slice.
+fn create_null_buffer(
+ valid_count: usize,
+ nulls_count: usize,
+ length: usize,
+ sort_options: SortOptions,
+) -> Option<Buffer> {
// Bytes needed to hold `length` bits, i.e. ceil(length / 8).
+ let null_capacity = (length / 8) + (length % 8 != 0) as usize;
// NOTE(review): `MutableBuffer::new` takes a capacity in bytes, so this
// reserves 8x the `null_capacity` bytes actually used by the `resize` below.
+ let mut mutable_null_buffer = MutableBuffer::new(null_capacity * 8);
+ mutable_null_buffer.resize(null_capacity, 0);
+
+ let mutable_null_buffer_slice = mutable_null_buffer.as_slice_mut();
+
+ if valid_count > 0 {
+ let mut count = valid_count;
+ let mut index = 0;
+ if sort_options.nulls_first {
// Skip the bytes that hold only leading nulls. If the nulls end in the
// middle of a byte, set the bits that follow them in that byte.
+ let remaining_nulls = nulls_count % 8;
+ index = nulls_count / 8;
+
+ if remaining_nulls != 0 {
+ let valid_values_count = min(8 - remaining_nulls, valid_count);
+ mutable_null_buffer_slice[index] =
+ ((1 << valid_values_count) - 1) << remaining_nulls;
+ count -= valid_values_count;
+ index += 1;
+ }
+ }
// Bytes that are entirely valid.
+ while count >= 8 {
+ mutable_null_buffer_slice[index] = u8::MAX;
+ index += 1;
+ count -= 8;
+ }
// Final partial byte: only its low `count` bits are valid.
+ if count != 0 {
+ mutable_null_buffer_slice[index] = (1 << count) - 1;
+ }
+ }
+
+ Some(mutable_null_buffer.into())
+}
+
+fn sort_native_type<T, U>(
+ values: &dyn Array,
+ options: Option<SortOptions>,
+) -> Result<ArrayRef, ArrowError>
+where
+ T: ArrowPrimitiveType,
+ U: ArrowNativeTypeOp,
+{
+ let sort_options = options.unwrap_or_default();
+ let values = values.as_primitive::<T>();
+
+ let result_capacity = values.len() * std::mem::size_of::<U>();
+ let mut mutable_buffer = MutableBuffer::new(result_capacity);
+ mutable_buffer.resize(result_capacity, 0);
+ let mutable_slice: &mut [U] = mutable_buffer.typed_data_mut();
+
+ let array_data = values.to_data();
+ let input_values: &[U] = array_data.buffer(0);
+
+ let mut null_bit_buffer = None;
+
+ let nulls_count = values.null_count();
+ let valid_count = values.len() - nulls_count;
+
+ if values.null_count() > 0 {
+ let nulls = array_data.nulls().unwrap();
+ let null_buffer = nulls.buffer().as_slice();
+
+ let mut mutable_slice_ptr = mutable_slice.as_mut_ptr();
+ let mut input_values_ptr = input_values.as_ptr();
+
+ if sort_options.nulls_first {
+ // This is safe since the offset in in bounds
+ unsafe {
+ mutable_slice_ptr = mutable_slice_ptr.add(values.null_count());
+ }
+ }
+
+ // This is safe since we are in bounds
+ let values_slice =
+ unsafe { slice::from_raw_parts_mut(mutable_slice_ptr, valid_count)
};
+
+ for mask in null_buffer {
+ let written_count =
+ compress_store::<U>(input_values_ptr, mutable_slice_ptr,
*mask);
+ // This is safe as the offset increments are within bounds
+ unsafe {
+ input_values_ptr = input_values_ptr.offset(8);
+ mutable_slice_ptr = mutable_slice_ptr.offset(written_count);
+ }
+ }
+
+ values_slice.sort_unstable_by(|a, b| a.compare(*b));
+ if sort_options.descending {
+ values_slice.reverse();
+ }
+
+ null_bit_buffer =
+ create_null_buffer(valid_count, nulls_count, values.len(),
sort_options);
+ } else {
+ mutable_slice.copy_from_slice(input_values);
+ mutable_slice.sort_unstable_by(|a, b| a.compare(*b));
+ if sort_options.descending {
+ mutable_slice.reverse();
+ }
+ }
+ // This is safe since data types match
Review Comment:
Done
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]