neilconway commented on code in PR #21083:
URL: https://github.com/apache/datafusion/pull/21083#discussion_r2968470345


##########
datafusion/functions-nested/src/sort.rs:
##########
@@ -206,54 +205,283 @@ fn array_sort_generic<OffsetSize: OffsetSizeTrait>(
     field: FieldRef,
     sort_options: Option<SortOptions>,
 ) -> Result<ArrayRef> {
+    let values = list_array.values();
+
+    if values.data_type().is_primitive() {
+        array_sort_primitive(list_array, field, sort_options)
+    } else {
+        array_sort_non_primitive(list_array, field, sort_options)
+    }
+}
+
+/// Sort each row of a primitive-typed ListArray using a custom in-place sort
+/// kernel.
+fn array_sort_primitive<OffsetSize: OffsetSizeTrait>(
+    list_array: &GenericListArray<OffsetSize>,
+    field: FieldRef,
+    sort_options: Option<SortOptions>,
+) -> Result<ArrayRef> {
+    let values = list_array.values().as_ref();
+    downcast_primitive_array! {
+        values => sort_primitive_list(values, list_array, field, sort_options),
+        _ => exec_err!("array_sort: unsupported primitive type")
+    }
+}
+
+fn sort_primitive_list<T: ArrowPrimitiveType, OffsetSize: OffsetSizeTrait>(
+    prim_values: &PrimitiveArray<T>,
+    list_array: &GenericListArray<OffsetSize>,
+    field: FieldRef,
+    sort_options: Option<SortOptions>,
+) -> Result<ArrayRef>
+where
+    T::Native: ArrowNativeTypeOp,
+{
+    if prim_values.null_count() > 0 {
+        sort_list_with_nulls(prim_values, list_array, field, sort_options)
+    } else {
+        sort_list_no_nulls(prim_values, list_array, field, sort_options)
+    }
+}
+
+/// Fast path for primitive values with no element-level nulls. Copies all
+/// values into a single `Vec` and sorts each row's slice in-place.
+fn sort_list_no_nulls<T: ArrowPrimitiveType, OffsetSize: OffsetSizeTrait>(
+    prim_values: &PrimitiveArray<T>,
+    list_array: &GenericListArray<OffsetSize>,
+    field: FieldRef,
+    sort_options: Option<SortOptions>,
+) -> Result<ArrayRef>
+where
+    T::Native: ArrowNativeTypeOp,
+{
     let row_count = list_array.len();
+    let offsets = list_array.offsets();
+    let values_start = offsets[0].as_usize();
+    let values_end = offsets[row_count].as_usize();
+
+    let descending = sort_options.is_some_and(|o| o.descending);
 
-    let mut array_lengths = vec![];
-    let mut arrays = vec![];
-    for i in 0..row_count {
-        if list_array.is_null(i) {
-            array_lengths.push(0);
+    // Copy all values into a mutable buffer
+    let mut values: Vec<T::Native> =
+        prim_values.values()[values_start..values_end].to_vec();
+
+    for (row_index, window) in offsets.windows(2).enumerate() {
+        if list_array.is_null(row_index) {
+            continue;
+        }
+        let start = window[0].as_usize() - values_start;
+        let end = window[1].as_usize() - values_start;
+        let slice = &mut values[start..end];
+        if descending {
+            slice.sort_unstable_by(|a, b| b.compare(*a));
         } else {
-            let arr_ref = list_array.value(i);
-
-            // arrow sort kernel does not support Structs, so use
-            // lexsort_to_indices instead:
-            // https://github.com/apache/arrow-rs/issues/6911#issuecomment-2562928843
-            let sorted_array = match arr_ref.data_type() {
-                DataType::Struct(_) => {
-                    let sort_columns: Vec<SortColumn> = vec![SortColumn {
-                        values: Arc::clone(&arr_ref),
-                        options: sort_options,
-                    }];
-                    let indices = compute::lexsort_to_indices(&sort_columns, None)?;
-                    compute::take(arr_ref.as_ref(), &indices, None)?
-                }
-                _ => {
-                    let arr_ref = arr_ref.as_ref();
-                    compute::sort(arr_ref, sort_options)?
-                }
-            };
-            array_lengths.push(sorted_array.len());
-            arrays.push(sorted_array);
+            slice.sort_unstable_by(|a, b| a.compare(*b));
         }
     }
 
-    let elements = arrays
-        .iter()
-        .map(|a| a.as_ref())
-        .collect::<Vec<&dyn Array>>();
+    let new_offsets = rebase_offsets(offsets);
+    let sorted_values = Arc::new(
+        PrimitiveArray::<T>::new(values.into(), None)
+            .with_data_type(prim_values.data_type().clone()),
+    );
 
-    let list_arr = if elements.is_empty() {
-        GenericListArray::<OffsetSize>::new_null(field, row_count)
-    } else {
-        GenericListArray::<OffsetSize>::new(
-            field,
-            OffsetBuffer::from_lengths(array_lengths),
-            Arc::new(compute::concat(elements.as_slice())?),
-            list_array.nulls().cloned(),
+    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
+        field,
+        new_offsets,
+        sorted_values,
+        list_array.nulls().cloned(),
+    )?))
+}
+
+/// Slow path for primitive values with element-level nulls.
+fn sort_list_with_nulls<T: ArrowPrimitiveType, OffsetSize: OffsetSizeTrait>(
+    prim_values: &PrimitiveArray<T>,
+    list_array: &GenericListArray<OffsetSize>,
+    field: FieldRef,
+    sort_options: Option<SortOptions>,
+) -> Result<ArrayRef>
+where
+    T::Native: ArrowNativeTypeOp,
+{
+    let row_count = list_array.len();
+    let offsets = list_array.offsets();
+    let values_start = offsets[0].as_usize();
+    let values_end = offsets[row_count].as_usize();
+    let total_values = values_end - values_start;
+
+    let descending = sort_options.is_some_and(|o| o.descending);
+    let nulls_first = sort_options.is_none_or(|o| o.nulls_first);
+
+    let mut out_values: Vec<T::Native> = vec![T::Native::default(); total_values];
+    let mut validity = BooleanBufferBuilder::new(total_values);
+
+    let src_nulls = prim_values.nulls().ok_or_else(|| {
+        internal_datafusion_err!(
+            "sort_list_with_nulls called but values have no null buffer"
         )
+    })?;
+    let src_values = prim_values.values();
+
+    for (row_index, window) in offsets.windows(2).enumerate() {
+        let start = window[0].as_usize();
+        let end = window[1].as_usize();
+        let row_len = end - start;
+        let out_start = start - values_start;
+
+        if list_array.is_null(row_index) || row_len == 0 {
+            validity.append_n(row_len, false);
+            continue;
+        }
+
+        let null_count = src_nulls.slice(start, row_len).null_count();
+        let valid_count = row_len - null_count;
+
+        // Compact valid values directly into the target region of the output
+        // buffer: after nulls (if nulls_first) or at the start (if nulls_last).
+        let valid_offset = if nulls_first { null_count } else { 0 };
+        let mut write_pos = out_start + valid_offset;
+        for i in start..end {
+            if src_nulls.is_valid(i) {
+                out_values[write_pos] = src_values[i];
+                write_pos += 1;
+            }
+        }
+
+        let valid_slice = &mut out_values
+            [out_start + valid_offset..out_start + valid_offset + valid_count];
+        if descending {
+            valid_slice.sort_unstable_by(|a, b| b.compare(*a));
+        } else {
+            valid_slice.sort_unstable_by(|a, b| a.compare(*b));
+        }
+
+        // Build validity bits
+        if nulls_first {
+            validity.append_n(null_count, false);
+            validity.append_n(valid_count, true);
+        } else {
+            validity.append_n(valid_count, true);
+            validity.append_n(null_count, false);
+        }
+    }
+
+    let new_offsets = rebase_offsets(offsets);
+
+    let null_buffer = NullBuffer::from(validity.finish());
+    let sorted_values = Arc::new(
+        PrimitiveArray::<T>::new(out_values.into(), Some(null_buffer))
+            .with_data_type(prim_values.data_type().clone()),
+    );
+
+    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
+        field,
+        new_offsets,
+        sorted_values,
+        list_array.nulls().cloned(),
+    )?))
+}
+
+/// Sort a non-primitive-typed ListArray by converting all rows at once using
+/// `RowConverter`, sorting row indices by comparing the encoded bytes (sort
+/// direction and null ordering are baked into the encoding), and materializing
+/// the result with a single `take()`.
+fn array_sort_non_primitive<OffsetSize: OffsetSizeTrait>(
+    list_array: &GenericListArray<OffsetSize>,
+    field: FieldRef,
+    sort_options: Option<SortOptions>,
+) -> Result<ArrayRef> {
+    let row_count = list_array.len();
+    let values = list_array.values();
+    let offsets = list_array.offsets();
+    let values_start = offsets[0].as_usize();
+    let total_values = offsets[row_count].as_usize() - values_start;
+
+    let converter = RowConverter::new(vec![SortField::new_with_options(

Review Comment:
   I briefly considered something like that, but I figured that all the pointer chasing would be pretty expensive. You're right that it's worth comparing, though.
   
   Here's a [quick Claude-generated version](https://gist.github.com/neilconway/c53faad4f597dfd1575d6a029dbc571e) -- lmk if you had something else in mind.
   
   Benchmarking it against the `RowConverter` approach, `RowConverter` wins for medium-sized (20 elements) and larger arrays, and loses to the index-based (`make_comparator`) approach for small arrays:
   
   ```
     ┌─────────────┬──────────┬─────────────────┬─────────────────┐
     │  Benchmark  │   main   │  RowConverter   │ make_comparator │
     ├─────────────┼──────────┼─────────────────┼─────────────────┤
     │ string/5    │ 2.12 ms  │ 727 µs (-66%)   │ 608 µs (-71%)   │
     ├─────────────┼──────────┼─────────────────┼─────────────────┤
     │ string/20   │ 5.94 ms  │ 4.42 ms (-26%)  │ 4.76 ms (-20%)  │
     ├─────────────┼──────────┼─────────────────┼─────────────────┤
     │ string/100  │ 26.8 ms  │ 22.6 ms (-16%)  │ 25.1 ms (-6%)   │
     ├─────────────┼──────────┼─────────────────┼─────────────────┤
     │ string/1000 │ 404.9 ms │ 293.1 ms (-28%) │ 403.9 ms (~0%)  │
     └─────────────┴──────────┴─────────────────┴─────────────────┘
   ```
   
   Not sure offhand what typical real-world workloads look like; lmk if you have a view.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to