This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new c5ab64cdc16 fix: lexsort_to_indices unsupported mixed types with list (#5455)
c5ab64cdc16 is described below
commit c5ab64cdc16b2b37943a865f2e4e8b832d24cdf3
Author: Andrew Lamb <[email protected]>
AuthorDate: Sat Mar 2 16:14:19 2024 -0500
fix: lexsort_to_indices unsupported mixed types with list (#5455)
* fix: lexsort_to_indices unsupported mixed types with list
* chore: pass clippy
---------
Co-authored-by: JasonLi <[email protected]>
---
arrow-ord/src/sort.rs | 313 +++++++++++++++++++++++++++++++++++++++++++++--
arrow/benches/lexsort.rs | 59 +++++++++
2 files changed, 363 insertions(+), 9 deletions(-)
diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs
index c25df3a480b..2c06057a84e 100644
--- a/arrow-ord/src/sort.rs
+++ b/arrow-ord/src/sort.rs
@@ -766,18 +766,61 @@ impl LexicographicalComparator {
pub fn try_new(columns: &[SortColumn]) -> Result<LexicographicalComparator, ArrowError> {
let compare_items = columns
.iter()
- .map(|column| {
- // flatten and convert build comparators
- let values = column.values.as_ref();
- Ok((
- values.logical_nulls(),
- build_compare(values, values)?,
- column.options.unwrap_or_default(),
- ))
- })
+ .map(Self::build_compare_item)
.collect::<Result<Vec<_>, ArrowError>>()?;
Ok(LexicographicalComparator { compare_items })
}
+
+ fn build_compare_item(column: &SortColumn) -> Result<LexicographicalCompareItem, ArrowError> {
+ let values = column.values.as_ref();
+ let options = column.options.unwrap_or_default();
+ let comparator = match values.data_type() {
+ DataType::List(_) => Self::build_list_compare(values.as_list::<i32>(), options)?,
+ DataType::LargeList(_) => Self::build_list_compare(values.as_list::<i64>(), options)?,
+ DataType::FixedSizeList(_, _) => {
+ Self::build_fixed_size_list_compare(values.as_fixed_size_list(), options)?
+ }
+ _ => build_compare(values, values)?,
+ };
+ Ok((values.logical_nulls(), comparator, options))
+ }
+
+ fn build_list_compare<O: OffsetSizeTrait>(
+ array: &GenericListArray<O>,
+ options: SortOptions,
+ ) -> Result<DynComparator, ArrowError> {
+ let rank = child_rank(array.values().as_ref(), options)?;
+ let offsets = array.offsets().clone();
+ let cmp = Box::new(move |i: usize, j: usize| {
+ macro_rules! nth_value {
+ ($INDEX:expr) => {{
+ let end = offsets[$INDEX + 1].as_usize();
+ let start = offsets[$INDEX].as_usize();
+ &rank[start..end]
+ }};
+ }
+ Ord::cmp(nth_value!(i), nth_value!(j))
+ });
+ Ok(cmp)
+ }
+
+ fn build_fixed_size_list_compare(
+ array: &FixedSizeListArray,
+ options: SortOptions,
+ ) -> Result<DynComparator, ArrowError> {
+ let rank = child_rank(array.values().as_ref(), options)?;
+ let size = array.value_length() as usize;
+ let cmp = Box::new(move |i: usize, j: usize| {
+ macro_rules! nth_value {
+ ($INDEX:expr) => {{
+ let start = $INDEX * size;
+ &rank[start..start + size]
+ }};
+ }
+ Ord::cmp(nth_value!(i), nth_value!(j))
+ });
+ Ok(cmp)
+ }
}
#[cfg(test)]
@@ -3592,6 +3635,258 @@ mod tests {
// Limiting by more rows than present is ok
test_lex_sort_arrays(input, slice_arrays(expected, 0, 5), Some(10));
+
+ // test with FixedSizeListArray, arrays order: [UInt32, FixedSizeList(UInt32, 1)]
+
+ // case1
+ let primitive_array_data = vec![
+ Some(2),
+ Some(3),
+ Some(2),
+ Some(0),
+ None,
+ Some(2),
+ Some(1),
+ Some(2),
+ ];
+ let list_array_data = vec![
+ None,
+ Some(vec![Some(4)]),
+ Some(vec![Some(3)]),
+ Some(vec![Some(1)]),
+ Some(vec![Some(5)]),
+ Some(vec![Some(0)]),
+ Some(vec![Some(2)]),
+ Some(vec![None]),
+ ];
+
+ let expected_primitive_array_data = vec![
+ None,
+ Some(0),
+ Some(1),
+ Some(2),
+ Some(2),
+ Some(2),
+ Some(2),
+ Some(3),
+ ];
+ let expected_list_array_data = vec![
+ Some(vec![Some(5)]),
+ Some(vec![Some(1)]),
+ Some(vec![Some(2)]),
+ None, // <-
+ Some(vec![None]),
+ Some(vec![Some(0)]),
+ Some(vec![Some(3)]), // <-
+ Some(vec![Some(4)]),
+ ];
+ test_lex_sort_mixed_types_with_fixed_size_list::<Int32Type>(
+ primitive_array_data.clone(),
+ list_array_data.clone(),
+ expected_primitive_array_data.clone(),
+ expected_list_array_data,
+ None,
+ None,
+ );
+
+ // case2
+ let primitive_array_options = SortOptions {
+ descending: false,
+ nulls_first: true,
+ };
+ let list_array_options = SortOptions {
+ descending: false,
+ nulls_first: false, // has been modified
+ };
+ let expected_list_array_data = vec![
+ Some(vec![Some(5)]),
+ Some(vec![Some(1)]),
+ Some(vec![Some(2)]),
+ Some(vec![Some(0)]), // <-
+ Some(vec![Some(3)]),
+ Some(vec![None]),
+ None, // <-
+ Some(vec![Some(4)]),
+ ];
+ test_lex_sort_mixed_types_with_fixed_size_list::<Int32Type>(
+ primitive_array_data.clone(),
+ list_array_data.clone(),
+ expected_primitive_array_data.clone(),
+ expected_list_array_data,
+ Some(primitive_array_options),
+ Some(list_array_options),
+ );
+
+ // case3
+ let primitive_array_options = SortOptions {
+ descending: false,
+ nulls_first: true,
+ };
+ let list_array_options = SortOptions {
+ descending: true, // has been modified
+ nulls_first: true,
+ };
+ let expected_list_array_data = vec![
+ Some(vec![Some(5)]),
+ Some(vec![Some(1)]),
+ Some(vec![Some(2)]),
+ None, // <-
+ Some(vec![None]),
+ Some(vec![Some(3)]),
+ Some(vec![Some(0)]), // <-
+ Some(vec![Some(4)]),
+ ];
+ test_lex_sort_mixed_types_with_fixed_size_list::<Int32Type>(
+ primitive_array_data.clone(),
+ list_array_data.clone(),
+ expected_primitive_array_data,
+ expected_list_array_data,
+ Some(primitive_array_options),
+ Some(list_array_options),
+ );
+
+ // test with ListArray/LargeListArray, arrays order: [List<UInt32>/LargeList<UInt32>, UInt32]
+
+ let list_array_data = vec![
+ Some(vec![Some(2), Some(1)]), // 0
+ None, // 10
+ Some(vec![Some(3)]), // 1
+ Some(vec![Some(2), Some(0)]), // 2
+ Some(vec![None, Some(2)]), // 3
+ Some(vec![Some(0)]), // none
+ None, // 11
+ Some(vec![Some(2), None]), // 4
+ Some(vec![None]), // 5
+ Some(vec![Some(2), Some(1)]), // 6
+ ];
+ let primitive_array_data = vec![
+ Some(0),
+ Some(10),
+ Some(1),
+ Some(2),
+ Some(3),
+ None,
+ Some(11),
+ Some(4),
+ Some(5),
+ Some(6),
+ ];
+ let expected_list_array_data = vec![
+ None,
+ None,
+ Some(vec![None]),
+ Some(vec![None, Some(2)]),
+ Some(vec![Some(0)]),
+ Some(vec![Some(2), None]),
+ Some(vec![Some(2), Some(0)]),
+ Some(vec![Some(2), Some(1)]),
+ Some(vec![Some(2), Some(1)]),
+ Some(vec![Some(3)]),
+ ];
+ let expected_primitive_array_data = vec![
+ Some(10),
+ Some(11),
+ Some(5),
+ Some(3),
+ None,
+ Some(4),
+ Some(2),
+ Some(0),
+ Some(6),
+ Some(1),
+ ];
+ test_lex_sort_mixed_types_with_list::<Int32Type>(
+ list_array_data.clone(),
+ primitive_array_data.clone(),
+ expected_list_array_data,
+ expected_primitive_array_data,
+ None,
+ None,
+ );
+ }
+
+ fn test_lex_sort_mixed_types_with_fixed_size_list<T>(
+ primitive_array_data: Vec<Option<T::Native>>,
+ list_array_data: Vec<Option<Vec<Option<T::Native>>>>,
+ expected_primitive_array_data: Vec<Option<T::Native>>,
+ expected_list_array_data: Vec<Option<Vec<Option<T::Native>>>>,
+ primitive_array_options: Option<SortOptions>,
+ list_array_options: Option<SortOptions>,
+ ) where
+ T: ArrowPrimitiveType,
+ PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
+ {
+ let input = vec![
+ SortColumn {
+ values: Arc::new(PrimitiveArray::<T>::from(primitive_array_data.clone()))
+ as ArrayRef,
+ options: primitive_array_options,
+ },
+ SortColumn {
+ values: Arc::new(FixedSizeListArray::from_iter_primitive::<T, _, _>(
+ list_array_data.clone(),
+ 1,
+ )) as ArrayRef,
+ options: list_array_options,
+ },
+ ];
+
+ let expected = vec![
+ Arc::new(PrimitiveArray::<T>::from(
+ expected_primitive_array_data.clone(),
+ )) as ArrayRef,
+ Arc::new(FixedSizeListArray::from_iter_primitive::<T, _, _>(
+ expected_list_array_data.clone(),
+ 1,
+ )) as ArrayRef,
+ ];
+
+ test_lex_sort_arrays(input.clone(), expected.clone(), None);
+ test_lex_sort_arrays(input.clone(), slice_arrays(expected.clone(), 0, 5), Some(5));
+ }
+
+ fn test_lex_sort_mixed_types_with_list<T>(
+ list_array_data: Vec<Option<Vec<Option<T::Native>>>>,
+ primitive_array_data: Vec<Option<T::Native>>,
+ expected_list_array_data: Vec<Option<Vec<Option<T::Native>>>>,
+ expected_primitive_array_data: Vec<Option<T::Native>>,
+ list_array_options: Option<SortOptions>,
+ primitive_array_options: Option<SortOptions>,
+ ) where
+ T: ArrowPrimitiveType,
+ PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
+ {
+ macro_rules! run_test {
+ ($ARRAY_TYPE:ident) => {
+ let input = vec![
+ SortColumn {
+ values: Arc::new(<$ARRAY_TYPE>::from_iter_primitive::<T, _, _>(
+ list_array_data.clone(),
+ )) as ArrayRef,
+ options: list_array_options.clone(),
+ },
+ SortColumn {
+ values: Arc::new(PrimitiveArray::<T>::from(primitive_array_data.clone()))
+ as ArrayRef,
+ options: primitive_array_options.clone(),
+ },
+ ];
+
+ let expected = vec![
+ Arc::new(<$ARRAY_TYPE>::from_iter_primitive::<T, _, _>(
+ expected_list_array_data.clone(),
+ )) as ArrayRef,
+ Arc::new(PrimitiveArray::<T>::from(
+ expected_primitive_array_data.clone(),
+ )) as ArrayRef,
+ ];
+
+ test_lex_sort_arrays(input.clone(), expected.clone(), None);
+ test_lex_sort_arrays(input.clone(), slice_arrays(expected.clone(), 0, 5), Some(5));
+ };
+ }
+ run_test!(ListArray);
+ run_test!(LargeListArray);
}
#[test]
diff --git a/arrow/benches/lexsort.rs b/arrow/benches/lexsort.rs
index bd2db1e5022..cd952299df4 100644
--- a/arrow/benches/lexsort.rs
+++ b/arrow/benches/lexsort.rs
@@ -20,8 +20,10 @@ use arrow::row::{RowConverter, SortField};
use arrow::util::bench_util::{
create_dict_from_values, create_primitive_array, create_string_array_with_len,
};
+use arrow::util::data_gen::create_random_array;
use arrow_array::types::Int32Type;
use arrow_array::{Array, ArrayRef, UInt32Array};
+use arrow_schema::{DataType, Field};
use criterion::{criterion_group, criterion_main, Criterion};
use std::sync::Arc;
@@ -33,6 +35,10 @@ enum Column {
Optional16CharString,
Optional50CharString,
Optional100Value50CharStringDict,
+ RequiredI32List,
+ OptionalI32List,
+ Required4CharStringList,
+ Optional4CharStringList,
}
impl std::fmt::Debug for Column {
@@ -44,6 +50,10 @@ impl std::fmt::Debug for Column {
Column::Optional16CharString => "str_opt(16)",
Column::Optional50CharString => "str_opt(50)",
Column::Optional100Value50CharStringDict => "dict(100,str_opt(50))",
+ Column::RequiredI32List => "i32_list",
+ Column::OptionalI32List => "i32_list_opt",
+ Column::Required4CharStringList => "str_list(4)",
+ Column::Optional4CharStringList => "str_list_opt(4)",
};
f.write_str(s)
}
@@ -70,6 +80,38 @@ impl Column {
&create_string_array_with_len::<i32>(100, 0., 50),
))
}
+ Column::RequiredI32List => {
+ let field = Field::new(
+ "_1",
+ DataType::List(Arc::new(Field::new("item", DataType::Int32, false))),
+ true,
+ );
+ create_random_array(&field, size, 0., 1.).unwrap()
+ }
+ Column::OptionalI32List => {
+ let field = Field::new(
+ "_1",
+ DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
+ true,
+ );
+ create_random_array(&field, size, 0.2, 1.).unwrap()
+ }
+ Column::Required4CharStringList => {
+ let field = Field::new(
+ "_1",
+ DataType::List(Arc::new(Field::new("item", DataType::Utf8, false))),
+ true,
+ );
+ create_random_array(&field, size, 0., 1.).unwrap()
+ }
+ Column::Optional4CharStringList => {
+ let field = Field::new(
+ "_1",
+ DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
+ true,
+ );
+ create_random_array(&field, size, 0.2, 1.).unwrap()
+ }
}
}
}
@@ -150,6 +192,23 @@ fn add_benchmark(c: &mut Criterion) {
Column::Optional100Value50CharStringDict,
Column::Optional50CharString,
],
+ &[Column::OptionalI32, Column::RequiredI32List],
+ &[Column::OptionalI32, Column::OptionalI32List],
+ &[Column::OptionalI32List, Column::OptionalI32],
+ &[Column::RequiredI32, Column::Required4CharStringList],
+ &[Column::Required4CharStringList, Column::RequiredI32],
+ &[Column::RequiredI32, Column::Optional4CharStringList],
+ &[Column::Optional4CharStringList, Column::RequiredI32],
+ &[
+ Column::RequiredI32,
+ Column::RequiredI32List,
+ Column::Required16CharString,
+ ],
+ &[
+ Column::OptionalI32,
+ Column::OptionalI32List,
+ Column::Optional50CharString,
+ ],
];
for case in cases {