jhorstmann commented on a change in pull request #8092:
URL: https://github.com/apache/arrow/pull/8092#discussion_r483965314



##########
File path: rust/arrow/src/compute/kernels/sort.rs
##########
@@ -148,58 +206,165 @@ impl Default for SortOptions {
     }
 }
 
-/// Sort primitive values, excluding floats
+/// Sort primitive values
 fn sort_primitive<T>(
     values: &ArrayRef,
-    value_indices: Vec<usize>,
+    value_indices: Vec<u32>,
     null_indices: Vec<u32>,
+    nan_indices: Vec<u32>,
     options: &SortOptions,
 ) -> Result<UInt32Array>
 where
     T: ArrowPrimitiveType,
     T::Native: std::cmp::PartialOrd,
 {
     let values = as_primitive_array::<T>(values);
+    sort_primitive_typed(values, value_indices, null_indices, nan_indices, 
options)
+}
+
+fn sort_primitive_typed<T>(
+    values: &PrimitiveArray<T>,
+    value_indices: Vec<u32>,
+    null_indices: Vec<u32>,
+    nan_indices: Vec<u32>,
+    options: &SortOptions,
+) -> Result<UInt32Array>
+where
+    T: ArrowPrimitiveType,
+    T::Native: std::cmp::PartialOrd,
+{
     // create tuples that are used for sorting
     let mut valids = value_indices
         .into_iter()
-        .map(|index| (index as u32, values.value(index)))
+        .map(|index| (index, values.value(index as usize)))
         .collect::<Vec<(u32, T::Native)>>();
+
+    let valids_len = valids.len();
+
     let mut nulls = null_indices;
+    let mut nans = nan_indices;
+
     if !options.descending {
-        valids.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or_else(|| 
Ordering::Greater));
+        valids.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("unexpected NaN"));
     } else {
-        valids.sort_by(|a, b| {
-            a.1.partial_cmp(&b.1)
-                .unwrap_or_else(|| Ordering::Greater)
-                .reverse()
-        });
+        valids.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("unexpected 
NaN").reverse());
+        // reverse to keep a stable ordering
+        nans.reverse();
         nulls.reverse();
     }
-    // collect the order of valid tuples
-    let mut valid_indices: Vec<u32> = valids.iter().map(|tuple| 
tuple.0).collect();
 
-    if options.nulls_first {
-        nulls.append(&mut valid_indices);
-        return Ok(UInt32Array::from(nulls));
+    // collect results directly into a buffer instead of a vec to avoid 
another aligned allocation
+    let mut result = MutableBuffer::new(values.len() * 
std::mem::size_of::<u32>());
+    // sets len to capacity so we can access the whole buffer as a typed slice
+    result.resize(values.len() * std::mem::size_of::<u32>())?;
+    {
+        let append_valids = move |dst_slice: &mut [u32]| {
+            debug_assert_eq!(dst_slice.len(), valids_len);
+            dst_slice
+                .iter_mut()
+                .zip(valids.into_iter())
+                .for_each(|(dst, src)| *dst = src.0)
+        };
+
+        let result_slice: &mut [u32] = result.typed_data_mut();
+
+        debug_assert_eq!(result_slice.len(), nulls.len() + nans.len() + 
valids_len);
+
+        if options.nulls_first {

Review comment:
       I refactored this a bit to extract the similar logic with
       `if descending ... else`, and also added comments there.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to