This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 4dfbca6  sort_primitive result is capped to the min of limit or values.len (#236)
4dfbca6 is described below

commit 4dfbca6e5791be400d2fd3ae863655445327650e
Author: Michael Edwards <[email protected]>
AuthorDate: Sat May 1 12:19:22 2021 +0200

    sort_primitive result is capped to the min of limit or values.len (#236)
    
    * sort_primitive result is capped to the min of limit or values.len
    
    fixes #235
    
    * Fixed length calculation to include nulls
    
    * Add more sort_primitive tests for sorts with limit
---
 arrow/src/compute/kernels/sort.rs | 59 ++++++++++++++++++++++++++++++++++-----
 1 file changed, 52 insertions(+), 7 deletions(-)

diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs
index 30341b6..9287425 100644
--- a/arrow/src/compute/kernels/sort.rs
+++ b/arrow/src/compute/kernels/sort.rs
@@ -487,24 +487,27 @@ where
         len = limit.min(len);
     }
     if !descending {
-        sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1));
+        sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
+            cmp(a.1, b.1)
+        });
     } else {
-        sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1).reverse());
+        sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
+            cmp(a.1, b.1).reverse()
+        });
         // reverse to keep a stable ordering
         nulls.reverse();
     }
 
     // collect results directly into a buffer instead of a vec to avoid 
another aligned allocation
-    let mut result = MutableBuffer::new(values.len() * 
std::mem::size_of::<u32>());
+    let result_capacity = len * std::mem::size_of::<u32>();
+    let mut result = MutableBuffer::new(result_capacity);
     // sets len to capacity so we can access the whole buffer as a typed slice
-    result.resize(values.len() * std::mem::size_of::<u32>(), 0);
+    result.resize(result_capacity, 0);
     let result_slice: &mut [u32] = result.typed_data_mut();
 
-    debug_assert_eq!(result_slice.len(), nulls_len + valids_len);
-
     if options.nulls_first {
         let size = nulls_len.min(len);
-        result_slice[0..nulls_len.min(len)].copy_from_slice(&nulls);
+        result_slice[0..size].copy_from_slice(&nulls[0..size]);
         if nulls_len < len {
             insert_valid_values(result_slice, nulls_len, &valids[0..len - 
size]);
         }
@@ -1556,6 +1559,48 @@ mod tests {
             Some(3),
             vec![Some(1.0), Some(2.0), Some(3.0)],
         );
+
+        // valid values less than limit with extra nulls
+        test_sort_primitive_arrays::<Float64Type>(
+            vec![Some(2.0), None, None, Some(1.0)],
+            Some(SortOptions {
+                descending: false,
+                nulls_first: false,
+            }),
+            Some(3),
+            vec![Some(1.0), Some(2.0), None],
+        );
+
+        test_sort_primitive_arrays::<Float64Type>(
+            vec![Some(2.0), None, None, Some(1.0)],
+            Some(SortOptions {
+                descending: false,
+                nulls_first: true,
+            }),
+            Some(3),
+            vec![None, None, Some(1.0)],
+        );
+
+        // more nulls than limit
+        test_sort_primitive_arrays::<Float64Type>(
+            vec![Some(2.0), None, None, None],
+            Some(SortOptions {
+                descending: false,
+                nulls_first: true,
+            }),
+            Some(2),
+            vec![None, None],
+        );
+
+        test_sort_primitive_arrays::<Float64Type>(
+            vec![Some(2.0), None, None, None],
+            Some(SortOptions {
+                descending: false,
+                nulls_first: false,
+            }),
+            Some(2),
+            vec![Some(2.0), None],
+        );
     }
 
     #[test]

Reply via email to