This is an automated email from the ASF dual-hosted git repository.

jeffreyvo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 85080638fc feat(arrow-ord): support boolean in `rank` and add tests for sorting lists of booleans (#6912)
85080638fc is described below

commit 85080638fc223a38950c2d6bda403190037829dc
Author: Raz Luvaton <[email protected]>
AuthorDate: Sat Jan 25 12:26:56 2025 +0200

    feat(arrow-ord): support boolean in `rank` and add tests for sorting lists of booleans (#6912)
    
    * feat(arrow-ord): support boolean in rank
    
    * add tests for sorting list of booleans
    
    * improve boolean rank performance
    
    * fix boolean rank to be sparse rather than dense
    
    * format
    
    * add test for boolean array without nulls
---
 arrow-ord/src/rank.rs | 165 +++++++++++++++++++++-
 arrow-ord/src/sort.rs | 382 +++++++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 544 insertions(+), 3 deletions(-)

diff --git a/arrow-ord/src/rank.rs b/arrow-ord/src/rank.rs
index e61cebef38..1b0d2a7e63 100644
--- a/arrow-ord/src/rank.rs
+++ b/arrow-ord/src/rank.rs
@@ -19,7 +19,9 @@
 
 use arrow_array::cast::AsArray;
 use arrow_array::types::*;
-use arrow_array::{downcast_primitive_array, Array, ArrowNativeTypeOp, 
GenericByteArray};
+use arrow_array::{
+    downcast_primitive_array, Array, ArrowNativeTypeOp, BooleanArray, 
GenericByteArray,
+};
 use arrow_buffer::NullBuffer;
 use arrow_schema::{ArrowError, DataType, SortOptions};
 use std::cmp::Ordering;
@@ -29,7 +31,11 @@ pub(crate) fn can_rank(data_type: &DataType) -> bool {
     data_type.is_primitive()
         || matches!(
             data_type,
-            DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | 
DataType::LargeBinary
+            DataType::Boolean
+                | DataType::Utf8
+                | DataType::LargeUtf8
+                | DataType::Binary
+                | DataType::LargeBinary
         )
 }
 
@@ -49,6 +55,7 @@ pub fn rank(array: &dyn Array, options: Option<SortOptions>) 
-> Result<Vec<u32>,
     let options = options.unwrap_or_default();
     let ranks = downcast_primitive_array! {
         array => primitive_rank(array.values(), array.nulls(), options),
+        DataType::Boolean => boolean_rank(array.as_boolean(), options),
         DataType::Utf8 => bytes_rank(array.as_bytes::<Utf8Type>(), options),
         DataType::LargeUtf8 => bytes_rank(array.as_bytes::<LargeUtf8Type>(), 
options),
         DataType::Binary => bytes_rank(array.as_bytes::<BinaryType>(), 
options),
@@ -135,6 +142,84 @@ where
     out
 }
 
+/// Return the index into the `[false, true, null]` rank table used by `boolean_rank`
+/// for a single slot of a boolean array.
+///
+/// The index is calculated as follows:
+/// - if `is_null` is true, the index is 2 (the underlying `value` bit is ignored)
+/// - if `is_null` is false and `value` is true, the index is 1
+/// - otherwise, the index is 0
+///
+/// `false` maps to 0 and `true` maps to 1 because those are their values when cast to
+/// an integer.
+#[inline]
+fn get_boolean_rank_index(value: bool, is_null: bool) -> usize {
+    let is_null_num = is_null as usize;
+    // Branch-free: `!is_null_num` is all ones for a valid slot (keeping the value bit),
+    // and has bit 0 cleared for a null slot, so a set value bit under a null cannot
+    // turn index 2 into 3.
+    (is_null_num << 1) | (value as usize & !is_null_num)
+}
+
+/// Rank a [`BooleanArray`] according to `options`.
+///
+/// Returns one rank per slot. Equal values, and all nulls, share a rank equal to the
+/// 1-based position of the *last* member of their group in sorted order. For example,
+/// with the default options `[true, true, null, false, false]` ranks as `[5, 5, 1, 3, 3]`.
+///
+/// A boolean array only has three distinct keys (`false`, `true` and null). Instead of
+/// sorting, this counts each group once and maps every slot through a three-entry table.
+#[inline(never)]
+fn boolean_rank(array: &BooleanArray, options: SortOptions) -> Vec<u32> {
+    // `true_count` only counts valid `true` slots (a set value bit under a null is not
+    // counted), so the three counts below partition `array.len()`.
+    let null_count = array.null_count() as u32;
+    let true_count = array.true_count() as u32;
+    let false_count = array.len() as u32 - null_count - true_count;
+
+    // Rank table indexed as [false, true, null]; the per-slot index comes from
+    // `get_boolean_rank_index`.
+    //
+    // The value for a rank is last value rank + own value count
+    // this means that if we have the following order: `false`, `true` and 
then `null`
+    // the ranks will be:
+    // - false: false_count
+    // - true: false_count + true_count
+    // - null: false_count + true_count + null_count
+    //
+    // If we have the following order: `null`, `false` and then `true`
+    // the ranks will be:
+    // - false: null_count + false_count
+    // - true: null_count + false_count + true_count
+    // - null: null_count
+    //
+    // You will notice that the last rank is always the total length of the 
array but we don't use it for readability on how the rank is calculated
+    let ranks_index: [u32; 3] = match (options.descending, 
options.nulls_first) {
+        // The order is null, true, false
+        (true, true) => [
+            null_count + true_count + false_count,
+            null_count + true_count,
+            null_count,
+        ],
+        // The order is true, false, null
+        (true, false) => [
+            true_count + false_count,
+            true_count,
+            true_count + false_count + null_count,
+        ],
+        // The order is null, false, true
+        (false, true) => [
+            null_count + false_count,
+            null_count + false_count + true_count,
+            null_count,
+        ],
+        // The order is false, true, null
+        (false, false) => [
+            false_count,
+            false_count + true_count,
+            false_count + true_count + null_count,
+        ],
+    };
+
+    // Only take the null-aware path when the null buffer actually contains nulls.
+    match array.nulls().filter(|n| n.null_count() > 0) {
+        // Some slots are null: combine each value bit with its validity bit.
+        Some(n) => array
+            .values()
+            .iter()
+            .zip(n.iter())
+            // The null buffer iterator yields validity, hence `!is_valid` for the null flag
+            .map(|(value, is_valid)| ranks_index[get_boolean_rank_index(value, 
!is_valid)])
+            .collect::<Vec<u32>>(),
+        // Every slot is valid: the value bit alone (false = 0, true = 1) selects the entry.
+        None => array
+            .values()
+            .iter()
+            .map(|value| ranks_index[value as usize])
+            .collect::<Vec<u32>>(),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -177,6 +262,82 @@ mod tests {
         assert_eq!(res, &[4, 6, 3, 6, 3, 3]);
     }
 
+    #[test]
+    fn test_get_boolean_rank_index() {
+        // A null slot maps to 2 regardless of its value bit.
+        assert_eq!(get_boolean_rank_index(true, true), 2);
+        assert_eq!(get_boolean_rank_index(false, true), 2);
+        // A valid slot maps to its value as 0/1.
+        assert_eq!(get_boolean_rank_index(true, false), 1);
+        assert_eq!(get_boolean_rank_index(false, false), 0);
+    }
+
+    #[test]
+    fn test_nullable_booleans() {
+        let descending = SortOptions {
+            descending: true,
+            nulls_first: true,
+        };
+
+        let nulls_last = SortOptions {
+            descending: false,
+            nulls_first: false,
+        };
+
+        let nulls_last_descending = SortOptions {
+            descending: true,
+            nulls_first: false,
+        };
+
+        let a = BooleanArray::from(vec![Some(true), Some(true), None, 
Some(false), Some(false)]);
+        // Default options order the slots as null, false, true; ties share the highest
+        // position of their group.
+        let res = rank(&a, None).unwrap();
+        assert_eq!(res, &[5, 5, 1, 3, 3]);
+
+        let res = rank(&a, Some(descending)).unwrap();
+        assert_eq!(res, &[3, 3, 1, 5, 5]);
+
+        let res = rank(&a, Some(nulls_last)).unwrap();
+        assert_eq!(res, &[4, 4, 5, 2, 2]);
+
+        let res = rank(&a, Some(nulls_last_descending)).unwrap();
+        assert_eq!(res, &[2, 2, 5, 4, 4]);
+
+        // A null slot whose underlying value bit is set (true) must still rank as null
+        // and must not be counted as `true`.
+        let nulls = NullBuffer::from(vec![true, true, false, true, true]);
+        let a = BooleanArray::new(vec![true, true, true, false, false].into(), 
Some(nulls));
+        let res = rank(&a, None).unwrap();
+        assert_eq!(res, &[5, 5, 1, 3, 3]);
+    }
+
+    #[test]
+    fn test_booleans() {
+        let descending = SortOptions {
+            descending: true,
+            nulls_first: true,
+        };
+
+        let nulls_last = SortOptions {
+            descending: false,
+            nulls_first: false,
+        };
+
+        let nulls_last_descending = SortOptions {
+            descending: true,
+            nulls_first: false,
+        };
+
+        // No nulls: `nulls_first` cannot affect the result, so options that share
+        // `descending` must produce identical ranks.
+        let a = BooleanArray::from(vec![true, false, false, false, true]);
+        let res = rank(&a, None).unwrap();
+        assert_eq!(res, &[5, 3, 3, 3, 5]);
+
+        let res = rank(&a, Some(descending)).unwrap();
+        assert_eq!(res, &[2, 5, 5, 5, 2]);
+
+        let res = rank(&a, Some(nulls_last)).unwrap();
+        assert_eq!(res, &[5, 3, 3, 3, 5]);
+
+        let res = rank(&a, Some(nulls_last_descending)).unwrap();
+        assert_eq!(res, &[2, 5, 5, 5, 2]);
+    }
+
     #[test]
     fn test_bytes() {
         let v = vec!["foo", "fo", "bar", "bar"];
diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs
index 51a6659e63..fa5e2b8b2f 100644
--- a/arrow-ord/src/sort.rs
+++ b/arrow-ord/src/sort.rs
@@ -785,12 +785,14 @@ impl LexicographicalComparator {
 mod tests {
     use super::*;
     use arrow_array::builder::{
-        FixedSizeListBuilder, Int64Builder, ListBuilder, PrimitiveRunBuilder,
+        BooleanBuilder, FixedSizeListBuilder, GenericListBuilder, 
Int64Builder, ListBuilder,
+        PrimitiveRunBuilder,
     };
     use arrow_buffer::{i256, NullBuffer};
     use arrow_schema::Field;
     use half::f16;
     use rand::rngs::StdRng;
+    use rand::seq::SliceRandom;
     use rand::{Rng, RngCore, SeedableRng};
 
     fn create_decimal128_array(data: Vec<Option<i128>>) -> Decimal128Array {
@@ -1541,6 +1543,384 @@ mod tests {
         );
     }
 
+    /// Test sorting boolean lists under each combination of with/without limit and 
GenericListArray/FixedSizeListArray
+    ///
+    /// The input data must have the same length for all list items so that we 
can test FixedSizeListArray
+    ///
+    /// In the limited run, `limit` is half of `expected_data.len()` (rounded down), and only
+    /// that prefix of `expected_data` is expected. Panics if any list in `data` or
+    /// `expected_data` has a different length than the first non-null list.
+    fn test_every_config_sort_boolean_list_arrays(
+        data: Vec<Option<Vec<Option<bool>>>>,
+        options: Option<SortOptions>,
+        expected_data: Vec<Option<Vec<Option<bool>>>>,
+    ) {
+        let first_length = data
+            .iter()
+            .find_map(|x| x.as_ref().map(|x| x.len()))
+            .unwrap_or(0);
+        let first_non_match_length = data
+            .iter()
+            .map(|x| x.as_ref().map(|x| x.len()).unwrap_or(first_length))
+            .position(|x| x != first_length);
+
+        assert_eq!(
+            first_non_match_length, None,
+            "All list items should have the same length {first_length}, input 
data is invalid"
+        );
+
+        let first_non_match_length = expected_data
+            .iter()
+            .map(|x| x.as_ref().map(|x| x.len()).unwrap_or(first_length))
+            .position(|x| x != first_length);
+
+        assert_eq!(
+            first_non_match_length, None,
+            "All list items should have the same length {first_length}, 
expected data is invalid"
+        );
+
+        // Half of the rows, used to exercise the `sort_limit` path
+        let limit = expected_data.len().saturating_div(2);
+
+        for &with_limit in &[false, true] {
+            let (limit, expected_data) = if with_limit {
+                (
+                    Some(limit),
+                    expected_data.iter().take(limit).cloned().collect(),
+                )
+            } else {
+                (None, expected_data.clone())
+            };
+
+            // `None` checks List/LargeList only; `Some(len)` also checks FixedSizeList
+            for &fixed_length in &[None, Some(first_length as i32)] {
+                test_sort_boolean_list_arrays(
+                    data.clone(),
+                    options,
+                    limit,
+                    expected_data.clone(),
+                    fixed_length,
+                );
+            }
+        }
+    }
+
+    /// Sort `data` with `options` and assert that the result equals `expected_data`.
+    ///
+    /// Uses `sort_limit` when `limit` is `Some`, otherwise `sort`. Both sides are built as a
+    /// List (i32 offsets) and a LargeList (i64 offsets), plus a FixedSizeList of that length
+    /// when `fixed_length` is `Some`.
+    fn test_sort_boolean_list_arrays(
+        data: Vec<Option<Vec<Option<bool>>>>,
+        options: Option<SortOptions>,
+        limit: Option<usize>,
+        expected_data: Vec<Option<Vec<Option<bool>>>>,
+        fixed_length: Option<i32>,
+    ) {
+        fn build_fixed_boolean_list_array(
+            data: Vec<Option<Vec<Option<bool>>>>,
+            fixed_length: i32,
+        ) -> ArrayRef {
+            let mut builder = FixedSizeListBuilder::new(
+                BooleanBuilder::with_capacity(fixed_length as usize),
+                fixed_length,
+            );
+            for sublist in data {
+                match sublist {
+                    Some(sublist) => {
+                        builder.values().extend(sublist);
+                        builder.append(true);
+                    }
+                    None => {
+                        // A null list still occupies `fixed_length` child slots; fill them
+                        // with null values
+                        builder
+                            .values()
+                            .extend(std::iter::repeat(None).take(fixed_length 
as usize));
+                        builder.append(false);
+                    }
+                }
+            }
+            Arc::new(builder.finish()) as ArrayRef
+        }
+
+        fn build_generic_boolean_list_array<OffsetSize: OffsetSizeTrait>(
+            data: Vec<Option<Vec<Option<bool>>>>,
+        ) -> ArrayRef {
+            let mut builder = GenericListBuilder::<OffsetSize, 
_>::new(BooleanBuilder::new());
+            builder.extend(data);
+            Arc::new(builder.finish()) as ArrayRef
+        }
+
+        // for FixedSizeList
+        if let Some(length) = fixed_length {
+            let input = build_fixed_boolean_list_array(data.clone(), length);
+            let sorted = match limit {
+                Some(_) => sort_limit(&(input as ArrayRef), options, 
limit).unwrap(),
+                _ => sort(&(input as ArrayRef), options).unwrap(),
+            };
+            let expected = 
build_fixed_boolean_list_array(expected_data.clone(), length);
+
+            assert_eq!(&sorted, &expected);
+        }
+
+        // for List
+        let input = build_generic_boolean_list_array::<i32>(data.clone());
+        let sorted = match limit {
+            Some(_) => sort_limit(&(input as ArrayRef), options, 
limit).unwrap(),
+            _ => sort(&(input as ArrayRef), options).unwrap(),
+        };
+        let expected = 
build_generic_boolean_list_array::<i32>(expected_data.clone());
+
+        assert_eq!(&sorted, &expected);
+
+        // for LargeList
+        let input = build_generic_boolean_list_array::<i64>(data.clone());
+        let sorted = match limit {
+            Some(_) => sort_limit(&(input as ArrayRef), options, 
limit).unwrap(),
+            _ => sort(&(input as ArrayRef), options).unwrap(),
+        };
+        let expected = 
build_generic_boolean_list_array::<i64>(expected_data.clone());
+
+        assert_eq!(&sorted, &expected);
+    }
+
+    #[test]
+    fn test_sort_list_of_booleans() {
+        // These are all the possible combinations of boolean values
+        // There are 3^3 + 1 = 28 possible combinations (3 values per position 
- [true, false, null] and 1 None value)
+        #[rustfmt::skip]
+        let mut cases = vec![
+            Some(vec![Some(true),  Some(true),  Some(true)]),
+            Some(vec![Some(true),  Some(true),  Some(false)]),
+            Some(vec![Some(true),  Some(true),  None]),
+
+            Some(vec![Some(true),  Some(false), Some(true)]),
+            Some(vec![Some(true),  Some(false), Some(false)]),
+            Some(vec![Some(true),  Some(false), None]),
+
+            Some(vec![Some(true),  None,        Some(true)]),
+            Some(vec![Some(true),  None,        Some(false)]),
+            Some(vec![Some(true),  None,        None]),
+
+            Some(vec![Some(false), Some(true),  Some(true)]),
+            Some(vec![Some(false), Some(true),  Some(false)]),
+            Some(vec![Some(false), Some(true),  None]),
+
+            Some(vec![Some(false), Some(false), Some(true)]),
+            Some(vec![Some(false), Some(false), Some(false)]),
+            Some(vec![Some(false), Some(false), None]),
+
+            Some(vec![Some(false), None,        Some(true)]),
+            Some(vec![Some(false), None,        Some(false)]),
+            Some(vec![Some(false), None,        None]),
+
+            Some(vec![None,        Some(true),  Some(true)]),
+            Some(vec![None,        Some(true),  Some(false)]),
+            Some(vec![None,        Some(true),  None]),
+
+            Some(vec![None,        Some(false), Some(true)]),
+            Some(vec![None,        Some(false), Some(false)]),
+            Some(vec![None,        Some(false), None]),
+
+            Some(vec![None,        None,        Some(true)]),
+            Some(vec![None,        None,        Some(false)]),
+            Some(vec![None,        None,        None]),
+            None,
+        ];
+
+        // Shuffle with a fixed seed so the input is not already ordered, while the test
+        // stays deterministic
+        cases.shuffle(&mut StdRng::seed_from_u64(42));
+
+        // The order is false, true, null
+        #[rustfmt::skip]
+        let expected_descending_false_nulls_first_false = vec![
+            Some(vec![Some(false), Some(false), Some(false)]),
+            Some(vec![Some(false), Some(false), Some(true)]),
+            Some(vec![Some(false), Some(false), None]),
+
+            Some(vec![Some(false), Some(true),  Some(false)]),
+            Some(vec![Some(false), Some(true),  Some(true)]),
+            Some(vec![Some(false), Some(true),  None]),
+
+            Some(vec![Some(false), None,        Some(false)]),
+            Some(vec![Some(false), None,        Some(true)]),
+            Some(vec![Some(false), None,        None]),
+
+            Some(vec![Some(true),  Some(false), Some(false)]),
+            Some(vec![Some(true),  Some(false), Some(true)]),
+            Some(vec![Some(true),  Some(false), None]),
+
+            Some(vec![Some(true),  Some(true),  Some(false)]),
+            Some(vec![Some(true),  Some(true),  Some(true)]),
+            Some(vec![Some(true),  Some(true),  None]),
+
+            Some(vec![Some(true),  None,        Some(false)]),
+            Some(vec![Some(true),  None,        Some(true)]),
+            Some(vec![Some(true),  None,        None]),
+
+            Some(vec![None,        Some(false), Some(false)]),
+            Some(vec![None,        Some(false), Some(true)]),
+            Some(vec![None,        Some(false), None]),
+
+            Some(vec![None,        Some(true),  Some(false)]),
+            Some(vec![None,        Some(true),  Some(true)]),
+            Some(vec![None,        Some(true),  None]),
+
+            Some(vec![None,        None,        Some(false)]),
+            Some(vec![None,        None,        Some(true)]),
+            Some(vec![None,        None,        None]),
+            None,
+        ];
+        test_every_config_sort_boolean_list_arrays(
+            cases.clone(),
+            Some(SortOptions {
+                descending: false,
+                nulls_first: false,
+            }),
+            expected_descending_false_nulls_first_false,
+        );
+
+        // The order is null, false, true
+        #[rustfmt::skip]
+        let expected_descending_false_nulls_first_true = vec![
+            None,
+
+            Some(vec![None,        None,        None]),
+            Some(vec![None,        None,        Some(false)]),
+            Some(vec![None,        None,        Some(true)]),
+
+            Some(vec![None,        Some(false), None]),
+            Some(vec![None,        Some(false), Some(false)]),
+            Some(vec![None,        Some(false), Some(true)]),
+
+            Some(vec![None,        Some(true),  None]),
+            Some(vec![None,        Some(true),  Some(false)]),
+            Some(vec![None,        Some(true),  Some(true)]),
+
+            Some(vec![Some(false), None,        None]),
+            Some(vec![Some(false), None,        Some(false)]),
+            Some(vec![Some(false), None,        Some(true)]),
+
+            Some(vec![Some(false), Some(false), None]),
+            Some(vec![Some(false), Some(false), Some(false)]),
+            Some(vec![Some(false), Some(false), Some(true)]),
+
+            Some(vec![Some(false), Some(true),  None]),
+            Some(vec![Some(false), Some(true),  Some(false)]),
+            Some(vec![Some(false), Some(true),  Some(true)]),
+
+            Some(vec![Some(true),  None,        None]),
+            Some(vec![Some(true),  None,        Some(false)]),
+            Some(vec![Some(true),  None,        Some(true)]),
+
+            Some(vec![Some(true),  Some(false), None]),
+            Some(vec![Some(true),  Some(false), Some(false)]),
+            Some(vec![Some(true),  Some(false), Some(true)]),
+
+            Some(vec![Some(true),  Some(true),  None]),
+            Some(vec![Some(true),  Some(true),  Some(false)]),
+            Some(vec![Some(true),  Some(true),  Some(true)]),
+        ];
+
+        test_every_config_sort_boolean_list_arrays(
+            cases.clone(),
+            Some(SortOptions {
+                descending: false,
+                nulls_first: true,
+            }),
+            expected_descending_false_nulls_first_true,
+        );
+
+        // The order is true, false, null
+        #[rustfmt::skip]
+        let expected_descending_true_nulls_first_false = vec![
+            Some(vec![Some(true),  Some(true),  Some(true)]),
+            Some(vec![Some(true),  Some(true),  Some(false)]),
+            Some(vec![Some(true),  Some(true),  None]),
+
+            Some(vec![Some(true),  Some(false), Some(true)]),
+            Some(vec![Some(true),  Some(false), Some(false)]),
+            Some(vec![Some(true),  Some(false), None]),
+
+            Some(vec![Some(true),  None,        Some(true)]),
+            Some(vec![Some(true),  None,        Some(false)]),
+            Some(vec![Some(true),  None,        None]),
+
+            Some(vec![Some(false), Some(true),  Some(true)]),
+            Some(vec![Some(false), Some(true),  Some(false)]),
+            Some(vec![Some(false), Some(true),  None]),
+
+            Some(vec![Some(false), Some(false), Some(true)]),
+            Some(vec![Some(false), Some(false), Some(false)]),
+            Some(vec![Some(false), Some(false), None]),
+
+            Some(vec![Some(false), None,        Some(true)]),
+            Some(vec![Some(false), None,        Some(false)]),
+            Some(vec![Some(false), None,        None]),
+
+            Some(vec![None,        Some(true),  Some(true)]),
+            Some(vec![None,        Some(true),  Some(false)]),
+            Some(vec![None,        Some(true),  None]),
+
+            Some(vec![None,        Some(false), Some(true)]),
+            Some(vec![None,        Some(false), Some(false)]),
+            Some(vec![None,        Some(false), None]),
+
+            Some(vec![None,        None,        Some(true)]),
+            Some(vec![None,        None,        Some(false)]),
+            Some(vec![None,        None,        None]),
+
+            None,
+        ];
+        test_every_config_sort_boolean_list_arrays(
+            cases.clone(),
+            Some(SortOptions {
+                descending: true,
+                nulls_first: false,
+            }),
+            expected_descending_true_nulls_first_false,
+        );
+
+        // The order is null, true, false
+        #[rustfmt::skip]
+        let expected_descending_true_nulls_first_true = vec![
+            None,
+
+            Some(vec![None,        None,        None]),
+            Some(vec![None,        None,        Some(true)]),
+            Some(vec![None,        None,        Some(false)]),
+
+            Some(vec![None,        Some(true),  None]),
+            Some(vec![None,        Some(true),  Some(true)]),
+            Some(vec![None,        Some(true),  Some(false)]),
+
+            Some(vec![None,        Some(false), None]),
+            Some(vec![None,        Some(false), Some(true)]),
+            Some(vec![None,        Some(false), Some(false)]),
+
+            Some(vec![Some(true),  None,        None]),
+            Some(vec![Some(true),  None,        Some(true)]),
+            Some(vec![Some(true),  None,        Some(false)]),
+
+            Some(vec![Some(true),  Some(true),  None]),
+            Some(vec![Some(true),  Some(true),  Some(true)]),
+            Some(vec![Some(true),  Some(true),  Some(false)]),
+
+            Some(vec![Some(true),  Some(false), None]),
+            Some(vec![Some(true),  Some(false), Some(true)]),
+            Some(vec![Some(true),  Some(false), Some(false)]),
+
+            Some(vec![Some(false), None,        None]),
+            Some(vec![Some(false), None,        Some(true)]),
+            Some(vec![Some(false), None,        Some(false)]),
+
+            Some(vec![Some(false), Some(true),  None]),
+            Some(vec![Some(false), Some(true),  Some(true)]),
+            Some(vec![Some(false), Some(true),  Some(false)]),
+
+            Some(vec![Some(false), Some(false), None]),
+            Some(vec![Some(false), Some(false), Some(true)]),
+            Some(vec![Some(false), Some(false), Some(false)]),
+        ];
+        // Runs with and without limit, as List, LargeList and FixedSizeList
+        test_every_config_sort_boolean_list_arrays(
+            cases.clone(),
+            Some(SortOptions {
+                descending: true,
+                nulls_first: true,
+            }),
+            expected_descending_true_nulls_first_true,
+        );
+    }
+
     #[test]
     fn test_sort_indices_decimal128() {
         // decimal default

Reply via email to