This is an automated email from the ASF dual-hosted git repository.
jeffreyvo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 85080638fc feat(arrow-ord): support boolean in `rank` and add tests
for sorting lists of booleans (#6912)
85080638fc is described below
commit 85080638fc223a38950c2d6bda403190037829dc
Author: Raz Luvaton <[email protected]>
AuthorDate: Sat Jan 25 12:26:56 2025 +0200
feat(arrow-ord): support boolean in `rank` and add tests for sorting lists
of booleans (#6912)
* feat(arrow-ord): support boolean in rank
* add tests for sorting list of booleans
* improve boolean rank performance
* fix boolean rank to be sparse rather than dense
* format
* add test for boolean array without nulls
---
arrow-ord/src/rank.rs | 165 +++++++++++++++++++++-
arrow-ord/src/sort.rs | 382 +++++++++++++++++++++++++++++++++++++++++++++++++-
2 files changed, 544 insertions(+), 3 deletions(-)
diff --git a/arrow-ord/src/rank.rs b/arrow-ord/src/rank.rs
index e61cebef38..1b0d2a7e63 100644
--- a/arrow-ord/src/rank.rs
+++ b/arrow-ord/src/rank.rs
@@ -19,7 +19,9 @@
use arrow_array::cast::AsArray;
use arrow_array::types::*;
-use arrow_array::{downcast_primitive_array, Array, ArrowNativeTypeOp,
GenericByteArray};
+use arrow_array::{
+ downcast_primitive_array, Array, ArrowNativeTypeOp, BooleanArray,
GenericByteArray,
+};
use arrow_buffer::NullBuffer;
use arrow_schema::{ArrowError, DataType, SortOptions};
use std::cmp::Ordering;
@@ -29,7 +31,11 @@ pub(crate) fn can_rank(data_type: &DataType) -> bool {
data_type.is_primitive()
|| matches!(
data_type,
- DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary |
DataType::LargeBinary
+ DataType::Boolean
+ | DataType::Utf8
+ | DataType::LargeUtf8
+ | DataType::Binary
+ | DataType::LargeBinary
)
}
@@ -49,6 +55,7 @@ pub fn rank(array: &dyn Array, options: Option<SortOptions>)
-> Result<Vec<u32>,
let options = options.unwrap_or_default();
let ranks = downcast_primitive_array! {
array => primitive_rank(array.values(), array.nulls(), options),
+ DataType::Boolean => boolean_rank(array.as_boolean(), options),
DataType::Utf8 => bytes_rank(array.as_bytes::<Utf8Type>(), options),
DataType::LargeUtf8 => bytes_rank(array.as_bytes::<LargeUtf8Type>(),
options),
DataType::Binary => bytes_rank(array.as_bytes::<BinaryType>(),
options),
@@ -135,6 +142,84 @@ where
out
}
+/// Return the index for the rank when ranking boolean array
+///
+/// The index is calculated as follows:
+/// if is_null is true, the index is 2
+/// if is_null is false and the value is true, the index is 1
+/// otherwise, the index is 0
+///
+/// false is 0 and true is 1 because those are the values each takes when cast to a number
+#[inline]
+fn get_boolean_rank_index(value: bool, is_null: bool) -> usize {
+ let is_null_num = is_null as usize;
+ (is_null_num << 1) | (value as usize & !is_null_num)
+}
+
+#[inline(never)]
+fn boolean_rank(array: &BooleanArray, options: SortOptions) -> Vec<u32> {
+ let null_count = array.null_count() as u32;
+ let true_count = array.true_count() as u32;
+ let false_count = array.len() as u32 - null_count - true_count;
+
+ // Rank values for [false, true, null] in that order
+ //
+ // The value for a rank is last value rank + own value count
+ // this means that if we have the following order: `false`, `true` and
then `null`
+ // the ranks will be:
+ // - false: false_count
+ // - true: false_count + true_count
+ // - null: false_count + true_count + null_count
+ //
+ // If we have the following order: `null`, `false` and then `true`
+ // the ranks will be:
+ // - false: null_count + false_count
+ // - true: null_count + false_count + true_count
+ // - null: null_count
+ //
+ // Note that the last rank is always the total length of the array; it is
written out as a sum anyway so that it is clear how each rank is calculated
+ let ranks_index: [u32; 3] = match (options.descending,
options.nulls_first) {
+ // The order is null, true, false
+ (true, true) => [
+ null_count + true_count + false_count,
+ null_count + true_count,
+ null_count,
+ ],
+ // The order is true, false, null
+ (true, false) => [
+ true_count + false_count,
+ true_count,
+ true_count + false_count + null_count,
+ ],
+ // The order is null, false, true
+ (false, true) => [
+ null_count + false_count,
+ null_count + false_count + true_count,
+ null_count,
+ ],
+ // The order is false, true, null
+ (false, false) => [
+ false_count,
+ false_count + true_count,
+ false_count + true_count + null_count,
+ ],
+ };
+
+ match array.nulls().filter(|n| n.null_count() > 0) {
+ Some(n) => array
+ .values()
+ .iter()
+ .zip(n.iter())
+ .map(|(value, is_valid)| ranks_index[get_boolean_rank_index(value,
!is_valid)])
+ .collect::<Vec<u32>>(),
+ None => array
+ .values()
+ .iter()
+ .map(|value| ranks_index[value as usize])
+ .collect::<Vec<u32>>(),
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -177,6 +262,82 @@ mod tests {
assert_eq!(res, &[4, 6, 3, 6, 3, 3]);
}
+ #[test]
+ fn test_get_boolean_rank_index() {
+ assert_eq!(get_boolean_rank_index(true, true), 2);
+ assert_eq!(get_boolean_rank_index(false, true), 2);
+ assert_eq!(get_boolean_rank_index(true, false), 1);
+ assert_eq!(get_boolean_rank_index(false, false), 0);
+ }
+
+ #[test]
+ fn test_nullable_booleans() {
+ let descending = SortOptions {
+ descending: true,
+ nulls_first: true,
+ };
+
+ let nulls_last = SortOptions {
+ descending: false,
+ nulls_first: false,
+ };
+
+ let nulls_last_descending = SortOptions {
+ descending: true,
+ nulls_first: false,
+ };
+
+ let a = BooleanArray::from(vec![Some(true), Some(true), None,
Some(false), Some(false)]);
+ let res = rank(&a, None).unwrap();
+ assert_eq!(res, &[5, 5, 1, 3, 3]);
+
+ let res = rank(&a, Some(descending)).unwrap();
+ assert_eq!(res, &[3, 3, 1, 5, 5]);
+
+ let res = rank(&a, Some(nulls_last)).unwrap();
+ assert_eq!(res, &[4, 4, 5, 2, 2]);
+
+ let res = rank(&a, Some(nulls_last_descending)).unwrap();
+ assert_eq!(res, &[2, 2, 5, 4, 4]);
+
+ // Test with a null slot whose underlying value is non-zero (true)
+ let nulls = NullBuffer::from(vec![true, true, false, true, true]);
+ let a = BooleanArray::new(vec![true, true, true, false, false].into(),
Some(nulls));
+ let res = rank(&a, None).unwrap();
+ assert_eq!(res, &[5, 5, 1, 3, 3]);
+ }
+
+ #[test]
+ fn test_booleans() {
+ let descending = SortOptions {
+ descending: true,
+ nulls_first: true,
+ };
+
+ let nulls_last = SortOptions {
+ descending: false,
+ nulls_first: false,
+ };
+
+ let nulls_last_descending = SortOptions {
+ descending: true,
+ nulls_first: false,
+ };
+
+ let a = BooleanArray::from(vec![true, false, false, false, true]);
+ let res = rank(&a, None).unwrap();
+ assert_eq!(res, &[5, 3, 3, 3, 5]);
+
+ let res = rank(&a, Some(descending)).unwrap();
+ assert_eq!(res, &[2, 5, 5, 5, 2]);
+
+ let res = rank(&a, Some(nulls_last)).unwrap();
+ assert_eq!(res, &[5, 3, 3, 3, 5]);
+
+ let res = rank(&a, Some(nulls_last_descending)).unwrap();
+ assert_eq!(res, &[2, 5, 5, 5, 2]);
+ }
+
#[test]
fn test_bytes() {
let v = vec!["foo", "fo", "bar", "bar"];
diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs
index 51a6659e63..fa5e2b8b2f 100644
--- a/arrow-ord/src/sort.rs
+++ b/arrow-ord/src/sort.rs
@@ -785,12 +785,14 @@ impl LexicographicalComparator {
mod tests {
use super::*;
use arrow_array::builder::{
- FixedSizeListBuilder, Int64Builder, ListBuilder, PrimitiveRunBuilder,
+ BooleanBuilder, FixedSizeListBuilder, GenericListBuilder,
Int64Builder, ListBuilder,
+ PrimitiveRunBuilder,
};
use arrow_buffer::{i256, NullBuffer};
use arrow_schema::Field;
use half::f16;
use rand::rngs::StdRng;
+ use rand::seq::SliceRandom;
use rand::{Rng, RngCore, SeedableRng};
fn create_decimal128_array(data: Vec<Option<i128>>) -> Decimal128Array {
@@ -1541,6 +1543,384 @@ mod tests {
);
}
+ /// Test sorting lists of booleans for every combination of with/without limit and
GenericListArray/FixedSizeListArray
+ ///
+ /// The input data must have the same length for all list items so that we
can test FixedSizeListArray
+ ///
+ fn test_every_config_sort_boolean_list_arrays(
+ data: Vec<Option<Vec<Option<bool>>>>,
+ options: Option<SortOptions>,
+ expected_data: Vec<Option<Vec<Option<bool>>>>,
+ ) {
+ let first_length = data
+ .iter()
+ .find_map(|x| x.as_ref().map(|x| x.len()))
+ .unwrap_or(0);
+ let first_non_match_length = data
+ .iter()
+ .map(|x| x.as_ref().map(|x| x.len()).unwrap_or(first_length))
+ .position(|x| x != first_length);
+
+ assert_eq!(
+ first_non_match_length, None,
+ "All list items should have the same length {first_length}, input
data is invalid"
+ );
+
+ let first_non_match_length = expected_data
+ .iter()
+ .map(|x| x.as_ref().map(|x| x.len()).unwrap_or(first_length))
+ .position(|x| x != first_length);
+
+ assert_eq!(
+ first_non_match_length, None,
+ "All list items should have the same length {first_length},
expected data is invalid"
+ );
+
+ let limit = expected_data.len().saturating_div(2);
+
+ for &with_limit in &[false, true] {
+ let (limit, expected_data) = if with_limit {
+ (
+ Some(limit),
+ expected_data.iter().take(limit).cloned().collect(),
+ )
+ } else {
+ (None, expected_data.clone())
+ };
+
+ for &fixed_length in &[None, Some(first_length as i32)] {
+ test_sort_boolean_list_arrays(
+ data.clone(),
+ options,
+ limit,
+ expected_data.clone(),
+ fixed_length,
+ );
+ }
+ }
+ }
+
+ fn test_sort_boolean_list_arrays(
+ data: Vec<Option<Vec<Option<bool>>>>,
+ options: Option<SortOptions>,
+ limit: Option<usize>,
+ expected_data: Vec<Option<Vec<Option<bool>>>>,
+ fixed_length: Option<i32>,
+ ) {
+ fn build_fixed_boolean_list_array(
+ data: Vec<Option<Vec<Option<bool>>>>,
+ fixed_length: i32,
+ ) -> ArrayRef {
+ let mut builder = FixedSizeListBuilder::new(
+ BooleanBuilder::with_capacity(fixed_length as usize),
+ fixed_length,
+ );
+ for sublist in data {
+ match sublist {
+ Some(sublist) => {
+ builder.values().extend(sublist);
+ builder.append(true);
+ }
+ None => {
+ builder
+ .values()
+ .extend(std::iter::repeat(None).take(fixed_length
as usize));
+ builder.append(false);
+ }
+ }
+ }
+ Arc::new(builder.finish()) as ArrayRef
+ }
+
+ fn build_generic_boolean_list_array<OffsetSize: OffsetSizeTrait>(
+ data: Vec<Option<Vec<Option<bool>>>>,
+ ) -> ArrayRef {
+ let mut builder = GenericListBuilder::<OffsetSize,
_>::new(BooleanBuilder::new());
+ builder.extend(data);
+ Arc::new(builder.finish()) as ArrayRef
+ }
+
+ // for FixedSizeList
+ if let Some(length) = fixed_length {
+ let input = build_fixed_boolean_list_array(data.clone(), length);
+ let sorted = match limit {
+ Some(_) => sort_limit(&(input as ArrayRef), options,
limit).unwrap(),
+ _ => sort(&(input as ArrayRef), options).unwrap(),
+ };
+ let expected =
build_fixed_boolean_list_array(expected_data.clone(), length);
+
+ assert_eq!(&sorted, &expected);
+ }
+
+ // for List
+ let input = build_generic_boolean_list_array::<i32>(data.clone());
+ let sorted = match limit {
+ Some(_) => sort_limit(&(input as ArrayRef), options,
limit).unwrap(),
+ _ => sort(&(input as ArrayRef), options).unwrap(),
+ };
+ let expected =
build_generic_boolean_list_array::<i32>(expected_data.clone());
+
+ assert_eq!(&sorted, &expected);
+
+ // for LargeList
+ let input = build_generic_boolean_list_array::<i64>(data.clone());
+ let sorted = match limit {
+ Some(_) => sort_limit(&(input as ArrayRef), options,
limit).unwrap(),
+ _ => sort(&(input as ArrayRef), options).unwrap(),
+ };
+ let expected =
build_generic_boolean_list_array::<i64>(expected_data.clone());
+
+ assert_eq!(&sorted, &expected);
+ }
+
+ #[test]
+ fn test_sort_list_of_booleans() {
+ // These are all the possible combinations of boolean values
+ // There are 3^3 + 1 = 28 possible combinations (3 positions, each one of
[true, false, null], plus 1 None list)
+ #[rustfmt::skip]
+ let mut cases = vec![
+ Some(vec![Some(true), Some(true), Some(true)]),
+ Some(vec![Some(true), Some(true), Some(false)]),
+ Some(vec![Some(true), Some(true), None]),
+
+ Some(vec![Some(true), Some(false), Some(true)]),
+ Some(vec![Some(true), Some(false), Some(false)]),
+ Some(vec![Some(true), Some(false), None]),
+
+ Some(vec![Some(true), None, Some(true)]),
+ Some(vec![Some(true), None, Some(false)]),
+ Some(vec![Some(true), None, None]),
+
+ Some(vec![Some(false), Some(true), Some(true)]),
+ Some(vec![Some(false), Some(true), Some(false)]),
+ Some(vec![Some(false), Some(true), None]),
+
+ Some(vec![Some(false), Some(false), Some(true)]),
+ Some(vec![Some(false), Some(false), Some(false)]),
+ Some(vec![Some(false), Some(false), None]),
+
+ Some(vec![Some(false), None, Some(true)]),
+ Some(vec![Some(false), None, Some(false)]),
+ Some(vec![Some(false), None, None]),
+
+ Some(vec![None, Some(true), Some(true)]),
+ Some(vec![None, Some(true), Some(false)]),
+ Some(vec![None, Some(true), None]),
+
+ Some(vec![None, Some(false), Some(true)]),
+ Some(vec![None, Some(false), Some(false)]),
+ Some(vec![None, Some(false), None]),
+
+ Some(vec![None, None, Some(true)]),
+ Some(vec![None, None, Some(false)]),
+ Some(vec![None, None, None]),
+ None,
+ ];
+
+ cases.shuffle(&mut StdRng::seed_from_u64(42));
+
+ // The order is false, true, null
+ #[rustfmt::skip]
+ let expected_descending_false_nulls_first_false = vec![
+ Some(vec![Some(false), Some(false), Some(false)]),
+ Some(vec![Some(false), Some(false), Some(true)]),
+ Some(vec![Some(false), Some(false), None]),
+
+ Some(vec![Some(false), Some(true), Some(false)]),
+ Some(vec![Some(false), Some(true), Some(true)]),
+ Some(vec![Some(false), Some(true), None]),
+
+ Some(vec![Some(false), None, Some(false)]),
+ Some(vec![Some(false), None, Some(true)]),
+ Some(vec![Some(false), None, None]),
+
+ Some(vec![Some(true), Some(false), Some(false)]),
+ Some(vec![Some(true), Some(false), Some(true)]),
+ Some(vec![Some(true), Some(false), None]),
+
+ Some(vec![Some(true), Some(true), Some(false)]),
+ Some(vec![Some(true), Some(true), Some(true)]),
+ Some(vec![Some(true), Some(true), None]),
+
+ Some(vec![Some(true), None, Some(false)]),
+ Some(vec![Some(true), None, Some(true)]),
+ Some(vec![Some(true), None, None]),
+
+ Some(vec![None, Some(false), Some(false)]),
+ Some(vec![None, Some(false), Some(true)]),
+ Some(vec![None, Some(false), None]),
+
+ Some(vec![None, Some(true), Some(false)]),
+ Some(vec![None, Some(true), Some(true)]),
+ Some(vec![None, Some(true), None]),
+
+ Some(vec![None, None, Some(false)]),
+ Some(vec![None, None, Some(true)]),
+ Some(vec![None, None, None]),
+ None,
+ ];
+ test_every_config_sort_boolean_list_arrays(
+ cases.clone(),
+ Some(SortOptions {
+ descending: false,
+ nulls_first: false,
+ }),
+ expected_descending_false_nulls_first_false,
+ );
+
+ // The order is null, false, true
+ #[rustfmt::skip]
+ let expected_descending_false_nulls_first_true = vec![
+ None,
+
+ Some(vec![None, None, None]),
+ Some(vec![None, None, Some(false)]),
+ Some(vec![None, None, Some(true)]),
+
+ Some(vec![None, Some(false), None]),
+ Some(vec![None, Some(false), Some(false)]),
+ Some(vec![None, Some(false), Some(true)]),
+
+ Some(vec![None, Some(true), None]),
+ Some(vec![None, Some(true), Some(false)]),
+ Some(vec![None, Some(true), Some(true)]),
+
+ Some(vec![Some(false), None, None]),
+ Some(vec![Some(false), None, Some(false)]),
+ Some(vec![Some(false), None, Some(true)]),
+
+ Some(vec![Some(false), Some(false), None]),
+ Some(vec![Some(false), Some(false), Some(false)]),
+ Some(vec![Some(false), Some(false), Some(true)]),
+
+ Some(vec![Some(false), Some(true), None]),
+ Some(vec![Some(false), Some(true), Some(false)]),
+ Some(vec![Some(false), Some(true), Some(true)]),
+
+ Some(vec![Some(true), None, None]),
+ Some(vec![Some(true), None, Some(false)]),
+ Some(vec![Some(true), None, Some(true)]),
+
+ Some(vec![Some(true), Some(false), None]),
+ Some(vec![Some(true), Some(false), Some(false)]),
+ Some(vec![Some(true), Some(false), Some(true)]),
+
+ Some(vec![Some(true), Some(true), None]),
+ Some(vec![Some(true), Some(true), Some(false)]),
+ Some(vec![Some(true), Some(true), Some(true)]),
+ ];
+
+ test_every_config_sort_boolean_list_arrays(
+ cases.clone(),
+ Some(SortOptions {
+ descending: false,
+ nulls_first: true,
+ }),
+ expected_descending_false_nulls_first_true,
+ );
+
+ // The order is true, false, null
+ #[rustfmt::skip]
+ let expected_descending_true_nulls_first_false = vec![
+ Some(vec![Some(true), Some(true), Some(true)]),
+ Some(vec![Some(true), Some(true), Some(false)]),
+ Some(vec![Some(true), Some(true), None]),
+
+ Some(vec![Some(true), Some(false), Some(true)]),
+ Some(vec![Some(true), Some(false), Some(false)]),
+ Some(vec![Some(true), Some(false), None]),
+
+ Some(vec![Some(true), None, Some(true)]),
+ Some(vec![Some(true), None, Some(false)]),
+ Some(vec![Some(true), None, None]),
+
+ Some(vec![Some(false), Some(true), Some(true)]),
+ Some(vec![Some(false), Some(true), Some(false)]),
+ Some(vec![Some(false), Some(true), None]),
+
+ Some(vec![Some(false), Some(false), Some(true)]),
+ Some(vec![Some(false), Some(false), Some(false)]),
+ Some(vec![Some(false), Some(false), None]),
+
+ Some(vec![Some(false), None, Some(true)]),
+ Some(vec![Some(false), None, Some(false)]),
+ Some(vec![Some(false), None, None]),
+
+ Some(vec![None, Some(true), Some(true)]),
+ Some(vec![None, Some(true), Some(false)]),
+ Some(vec![None, Some(true), None]),
+
+ Some(vec![None, Some(false), Some(true)]),
+ Some(vec![None, Some(false), Some(false)]),
+ Some(vec![None, Some(false), None]),
+
+ Some(vec![None, None, Some(true)]),
+ Some(vec![None, None, Some(false)]),
+ Some(vec![None, None, None]),
+
+ None,
+ ];
+ test_every_config_sort_boolean_list_arrays(
+ cases.clone(),
+ Some(SortOptions {
+ descending: true,
+ nulls_first: false,
+ }),
+ expected_descending_true_nulls_first_false,
+ );
+
+ // The order is null, true, false
+ #[rustfmt::skip]
+ let expected_descending_true_nulls_first_true = vec![
+ None,
+
+ Some(vec![None, None, None]),
+ Some(vec![None, None, Some(true)]),
+ Some(vec![None, None, Some(false)]),
+
+ Some(vec![None, Some(true), None]),
+ Some(vec![None, Some(true), Some(true)]),
+ Some(vec![None, Some(true), Some(false)]),
+
+ Some(vec![None, Some(false), None]),
+ Some(vec![None, Some(false), Some(true)]),
+ Some(vec![None, Some(false), Some(false)]),
+
+ Some(vec![Some(true), None, None]),
+ Some(vec![Some(true), None, Some(true)]),
+ Some(vec![Some(true), None, Some(false)]),
+
+ Some(vec![Some(true), Some(true), None]),
+ Some(vec![Some(true), Some(true), Some(true)]),
+ Some(vec![Some(true), Some(true), Some(false)]),
+
+ Some(vec![Some(true), Some(false), None]),
+ Some(vec![Some(true), Some(false), Some(true)]),
+ Some(vec![Some(true), Some(false), Some(false)]),
+
+ Some(vec![Some(false), None, None]),
+ Some(vec![Some(false), None, Some(true)]),
+ Some(vec![Some(false), None, Some(false)]),
+
+ Some(vec![Some(false), Some(true), None]),
+ Some(vec![Some(false), Some(true), Some(true)]),
+ Some(vec![Some(false), Some(true), Some(false)]),
+
+ Some(vec![Some(false), Some(false), None]),
+ Some(vec![Some(false), Some(false), Some(true)]),
+ Some(vec![Some(false), Some(false), Some(false)]),
+ ];
+ // Runs every configuration: with/without limit, for List, LargeList and FixedSizeList
+ test_every_config_sort_boolean_list_arrays(
+ cases.clone(),
+ Some(SortOptions {
+ descending: true,
+ nulls_first: true,
+ }),
+ expected_descending_true_nulls_first_true,
+ );
+ }
+
#[test]
fn test_sort_indices_decimal128() {
// decimal default