Dandandan commented on code in PR #21083:
URL: https://github.com/apache/datafusion/pull/21083#discussion_r2969206939
##########
datafusion/functions-nested/src/sort.rs:
##########
@@ -206,54 +205,283 @@ fn array_sort_generic<OffsetSize: OffsetSizeTrait>(
field: FieldRef,
sort_options: Option<SortOptions>,
) -> Result<ArrayRef> {
+ let values = list_array.values();
+
+ if values.data_type().is_primitive() {
+ array_sort_primitive(list_array, field, sort_options)
+ } else {
+ array_sort_non_primitive(list_array, field, sort_options)
+ }
+}
+
+/// Sort each row of a primitive-typed ListArray using a custom in-place sort
+/// kernel.
+fn array_sort_primitive<OffsetSize: OffsetSizeTrait>(
+ list_array: &GenericListArray<OffsetSize>,
+ field: FieldRef,
+ sort_options: Option<SortOptions>,
+) -> Result<ArrayRef> {
+ let values = list_array.values().as_ref();
+ downcast_primitive_array! {
+ values => sort_primitive_list(values, list_array, field, sort_options),
+ _ => exec_err!("array_sort: unsupported primitive type")
+ }
+}
+
+fn sort_primitive_list<T: ArrowPrimitiveType, OffsetSize: OffsetSizeTrait>(
+ prim_values: &PrimitiveArray<T>,
+ list_array: &GenericListArray<OffsetSize>,
+ field: FieldRef,
+ sort_options: Option<SortOptions>,
+) -> Result<ArrayRef>
+where
+ T::Native: ArrowNativeTypeOp,
+{
+ if prim_values.null_count() > 0 {
+ sort_list_with_nulls(prim_values, list_array, field, sort_options)
+ } else {
+ sort_list_no_nulls(prim_values, list_array, field, sort_options)
+ }
+}
+
+/// Fast path for primitive values with no element-level nulls. Copies all
+/// values into a single `Vec` and sorts each row's slice in-place.
+fn sort_list_no_nulls<T: ArrowPrimitiveType, OffsetSize: OffsetSizeTrait>(
+ prim_values: &PrimitiveArray<T>,
+ list_array: &GenericListArray<OffsetSize>,
+ field: FieldRef,
+ sort_options: Option<SortOptions>,
+) -> Result<ArrayRef>
+where
+ T::Native: ArrowNativeTypeOp,
+{
let row_count = list_array.len();
+ let offsets = list_array.offsets();
+ let values_start = offsets[0].as_usize();
+ let values_end = offsets[row_count].as_usize();
+
+ let descending = sort_options.is_some_and(|o| o.descending);
- let mut array_lengths = vec![];
- let mut arrays = vec![];
- for i in 0..row_count {
- if list_array.is_null(i) {
- array_lengths.push(0);
+ // Copy all values into a mutable buffer
+ let mut values: Vec<T::Native> =
+ prim_values.values()[values_start..values_end].to_vec();
+
+ for (row_index, window) in offsets.windows(2).enumerate() {
+ if list_array.is_null(row_index) {
+ continue;
+ }
+ let start = window[0].as_usize() - values_start;
+ let end = window[1].as_usize() - values_start;
+ let slice = &mut values[start..end];
+ if descending {
+ slice.sort_unstable_by(|a, b| b.compare(*a));
} else {
- let arr_ref = list_array.value(i);
-
- // arrow sort kernel does not support Structs, so use
- // lexsort_to_indices instead:
- // https://github.com/apache/arrow-rs/issues/6911#issuecomment-2562928843
- let sorted_array = match arr_ref.data_type() {
- DataType::Struct(_) => {
- let sort_columns: Vec<SortColumn> = vec![SortColumn {
- values: Arc::clone(&arr_ref),
- options: sort_options,
- }];
- let indices = compute::lexsort_to_indices(&sort_columns, None)?;
- compute::take(arr_ref.as_ref(), &indices, None)?
- }
- _ => {
- let arr_ref = arr_ref.as_ref();
- compute::sort(arr_ref, sort_options)?
- }
- };
- array_lengths.push(sorted_array.len());
- arrays.push(sorted_array);
+ slice.sort_unstable_by(|a, b| a.compare(*b));
}
}
- let elements = arrays
- .iter()
- .map(|a| a.as_ref())
- .collect::<Vec<&dyn Array>>();
+ let new_offsets = rebase_offsets(offsets);
+ let sorted_values = Arc::new(
+ PrimitiveArray::<T>::new(values.into(), None)
+ .with_data_type(prim_values.data_type().clone()),
+ );
- let list_arr = if elements.is_empty() {
- GenericListArray::<OffsetSize>::new_null(field, row_count)
- } else {
- GenericListArray::<OffsetSize>::new(
- field,
- OffsetBuffer::from_lengths(array_lengths),
- Arc::new(compute::concat(elements.as_slice())?),
- list_array.nulls().cloned(),
+ Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
+ field,
+ new_offsets,
+ sorted_values,
+ list_array.nulls().cloned(),
+ )?))
+}
+
+/// Slow path for primitive values with element-level nulls.
+fn sort_list_with_nulls<T: ArrowPrimitiveType, OffsetSize: OffsetSizeTrait>(
+ prim_values: &PrimitiveArray<T>,
+ list_array: &GenericListArray<OffsetSize>,
+ field: FieldRef,
+ sort_options: Option<SortOptions>,
+) -> Result<ArrayRef>
+where
+ T::Native: ArrowNativeTypeOp,
+{
+ let row_count = list_array.len();
+ let offsets = list_array.offsets();
+ let values_start = offsets[0].as_usize();
+ let values_end = offsets[row_count].as_usize();
+ let total_values = values_end - values_start;
+
+ let descending = sort_options.is_some_and(|o| o.descending);
+ let nulls_first = sort_options.is_none_or(|o| o.nulls_first);
+
+ let mut out_values: Vec<T::Native> = vec![T::Native::default(); total_values];
+ let mut validity = BooleanBufferBuilder::new(total_values);
+
+ let src_nulls = prim_values.nulls().ok_or_else(|| {
+ internal_datafusion_err!(
+ "sort_list_with_nulls called but values have no null buffer"
)
+ })?;
+ let src_values = prim_values.values();
+
+ for (row_index, window) in offsets.windows(2).enumerate() {
+ let start = window[0].as_usize();
+ let end = window[1].as_usize();
+ let row_len = end - start;
+ let out_start = start - values_start;
+
+ if list_array.is_null(row_index) || row_len == 0 {
+ validity.append_n(row_len, false);
+ continue;
+ }
+
+ let null_count = src_nulls.slice(start, row_len).null_count();
+ let valid_count = row_len - null_count;
+
+ // Compact valid values directly into the target region of the output
+ // buffer: after nulls (if nulls_first) or at the start (if nulls_last).
+ let valid_offset = if nulls_first { null_count } else { 0 };
+ let mut write_pos = out_start + valid_offset;
+ for i in start..end {
+ if src_nulls.is_valid(i) {
+ out_values[write_pos] = src_values[i];
+ write_pos += 1;
+ }
+ }
+
+ let valid_slice = &mut out_values
+ [out_start + valid_offset..out_start + valid_offset + valid_count];
+ if descending {
+ valid_slice.sort_unstable_by(|a, b| b.compare(*a));
+ } else {
+ valid_slice.sort_unstable_by(|a, b| a.compare(*b));
+ }
+
+ // Build validity bits
+ if nulls_first {
+ validity.append_n(null_count, false);
+ validity.append_n(valid_count, true);
+ } else {
+ validity.append_n(valid_count, true);
+ validity.append_n(null_count, false);
+ }
+ }
+
+ let new_offsets = rebase_offsets(offsets);
+
+ let null_buffer = NullBuffer::from(validity.finish());
+ let sorted_values = Arc::new(
+ PrimitiveArray::<T>::new(out_values.into(), Some(null_buffer))
+ .with_data_type(prim_values.data_type().clone()),
+ );
+
+ Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
+ field,
+ new_offsets,
+ sorted_values,
+ list_array.nulls().cloned(),
+ )?))
+}
+
+/// Sort a non-primitive-typed ListArray by converting all rows at once using
+/// `RowConverter`, and then sort row indices by comparing encoded bytes (sort
+/// direction and null ordering are baked into the encoding), and materialize
+/// the result with a single `take()`.
+fn array_sort_non_primitive<OffsetSize: OffsetSizeTrait>(
+ list_array: &GenericListArray<OffsetSize>,
+ field: FieldRef,
+ sort_options: Option<SortOptions>,
+) -> Result<ArrayRef> {
+ let row_count = list_array.len();
+ let values = list_array.values();
+ let offsets = list_array.offsets();
+ let values_start = offsets[0].as_usize();
+ let total_values = offsets[row_count].as_usize() - values_start;
+
+ let converter = RowConverter::new(vec![SortField::new_with_options(
Review Comment:
The Claude code looks close to what I had in mind, except for the
`make_comparator` and some non-null specialization / `Vec::push`, which
cannot be inlined and generate slow / branchy code.
I guess for strings you could also use the same tricks as the arrow
kernels: create an inlined key / small string for fast
comparisons.
RowConverter is fine of course, but it has a higher fixed overhead
upfront and higher space usage as well, so for single columns I think
type specialization always wins.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]