This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new c5ab64cdc16 fix: lexsort_to_indices unsupported mixed types with list (#5455)
c5ab64cdc16 is described below

commit c5ab64cdc16b2b37943a865f2e4e8b832d24cdf3
Author: Andrew Lamb <[email protected]>
AuthorDate: Sat Mar 2 16:14:19 2024 -0500

    fix: lexsort_to_indices unsupported mixed types with list (#5455)
    
    * fix: lexsort_to_indices unsupported mixed types with list
    
    * chore: pass clippy
    
    ---------
    
    Co-authored-by: JasonLi <[email protected]>
---
 arrow-ord/src/sort.rs    | 313 +++++++++++++++++++++++++++++++++++++++++++++--
 arrow/benches/lexsort.rs |  59 +++++++++
 2 files changed, 363 insertions(+), 9 deletions(-)

diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs
index c25df3a480b..2c06057a84e 100644
--- a/arrow-ord/src/sort.rs
+++ b/arrow-ord/src/sort.rs
@@ -766,18 +766,61 @@ impl LexicographicalComparator {
     pub fn try_new(columns: &[SortColumn]) -> Result<LexicographicalComparator, ArrowError> {
         let compare_items = columns
             .iter()
-            .map(|column| {
-                // flatten and convert build comparators
-                let values = column.values.as_ref();
-                Ok((
-                    values.logical_nulls(),
-                    build_compare(values, values)?,
-                    column.options.unwrap_or_default(),
-                ))
-            })
+            .map(Self::build_compare_item)
             .collect::<Result<Vec<_>, ArrowError>>()?;
         Ok(LexicographicalComparator { compare_items })
     }
+
+    fn build_compare_item(column: &SortColumn) -> Result<LexicographicalCompareItem, ArrowError> {
+        let values = column.values.as_ref();
+        let options = column.options.unwrap_or_default();
+        let comparator = match values.data_type() {
+            DataType::List(_) => Self::build_list_compare(values.as_list::<i32>(), options)?,
+            DataType::LargeList(_) => Self::build_list_compare(values.as_list::<i64>(), options)?,
+            DataType::FixedSizeList(_, _) => {
+                Self::build_fixed_size_list_compare(values.as_fixed_size_list(), options)?
+            }
+            _ => build_compare(values, values)?,
+        };
+        Ok((values.logical_nulls(), comparator, options))
+    }
+
+    fn build_list_compare<O: OffsetSizeTrait>(
+        array: &GenericListArray<O>,
+        options: SortOptions,
+    ) -> Result<DynComparator, ArrowError> {
+        let rank = child_rank(array.values().as_ref(), options)?;
+        let offsets = array.offsets().clone();
+        let cmp = Box::new(move |i: usize, j: usize| {
+            macro_rules! nth_value {
+                ($INDEX:expr) => {{
+                    let end = offsets[$INDEX + 1].as_usize();
+                    let start = offsets[$INDEX].as_usize();
+                    &rank[start..end]
+                }};
+            }
+            Ord::cmp(nth_value!(i), nth_value!(j))
+        });
+        Ok(cmp)
+    }
+
+    fn build_fixed_size_list_compare(
+        array: &FixedSizeListArray,
+        options: SortOptions,
+    ) -> Result<DynComparator, ArrowError> {
+        let rank = child_rank(array.values().as_ref(), options)?;
+        let size = array.value_length() as usize;
+        let cmp = Box::new(move |i: usize, j: usize| {
+            macro_rules! nth_value {
+                ($INDEX:expr) => {{
+                    let start = $INDEX * size;
+                    &rank[start..start + size]
+                }};
+            }
+            Ord::cmp(nth_value!(i), nth_value!(j))
+        });
+        Ok(cmp)
+    }
 }
 
 #[cfg(test)]
@@ -3592,6 +3635,258 @@ mod tests {
 
         // Limiting by more rows than present is ok
         test_lex_sort_arrays(input, slice_arrays(expected, 0, 5), Some(10));
+
+        // test with FixedSizeListArray, arrays order: [UInt32, FixedSizeList(UInt32, 1)]
+
+        // case1
+        let primitive_array_data = vec![
+            Some(2),
+            Some(3),
+            Some(2),
+            Some(0),
+            None,
+            Some(2),
+            Some(1),
+            Some(2),
+        ];
+        let list_array_data = vec![
+            None,
+            Some(vec![Some(4)]),
+            Some(vec![Some(3)]),
+            Some(vec![Some(1)]),
+            Some(vec![Some(5)]),
+            Some(vec![Some(0)]),
+            Some(vec![Some(2)]),
+            Some(vec![None]),
+        ];
+
+        let expected_primitive_array_data = vec![
+            None,
+            Some(0),
+            Some(1),
+            Some(2),
+            Some(2),
+            Some(2),
+            Some(2),
+            Some(3),
+        ];
+        let expected_list_array_data = vec![
+            Some(vec![Some(5)]),
+            Some(vec![Some(1)]),
+            Some(vec![Some(2)]),
+            None, // <-
+            Some(vec![None]),
+            Some(vec![Some(0)]),
+            Some(vec![Some(3)]), // <-
+            Some(vec![Some(4)]),
+        ];
+        test_lex_sort_mixed_types_with_fixed_size_list::<Int32Type>(
+            primitive_array_data.clone(),
+            list_array_data.clone(),
+            expected_primitive_array_data.clone(),
+            expected_list_array_data,
+            None,
+            None,
+        );
+
+        // case2
+        let primitive_array_options = SortOptions {
+            descending: false,
+            nulls_first: true,
+        };
+        let list_array_options = SortOptions {
+            descending: false,
+            nulls_first: false, // has been modified
+        };
+        let expected_list_array_data = vec![
+            Some(vec![Some(5)]),
+            Some(vec![Some(1)]),
+            Some(vec![Some(2)]),
+            Some(vec![Some(0)]), // <-
+            Some(vec![Some(3)]),
+            Some(vec![None]),
+            None, // <-
+            Some(vec![Some(4)]),
+        ];
+        test_lex_sort_mixed_types_with_fixed_size_list::<Int32Type>(
+            primitive_array_data.clone(),
+            list_array_data.clone(),
+            expected_primitive_array_data.clone(),
+            expected_list_array_data,
+            Some(primitive_array_options),
+            Some(list_array_options),
+        );
+
+        // case3
+        let primitive_array_options = SortOptions {
+            descending: false,
+            nulls_first: true,
+        };
+        let list_array_options = SortOptions {
+            descending: true, // has been modified
+            nulls_first: true,
+        };
+        let expected_list_array_data = vec![
+            Some(vec![Some(5)]),
+            Some(vec![Some(1)]),
+            Some(vec![Some(2)]),
+            None, // <-
+            Some(vec![None]),
+            Some(vec![Some(3)]),
+            Some(vec![Some(0)]), // <-
+            Some(vec![Some(4)]),
+        ];
+        test_lex_sort_mixed_types_with_fixed_size_list::<Int32Type>(
+            primitive_array_data.clone(),
+            list_array_data.clone(),
+            expected_primitive_array_data,
+            expected_list_array_data,
+            Some(primitive_array_options),
+            Some(list_array_options),
+        );
+
+        // test with ListArray/LargeListArray, arrays order: [List<UInt32>/LargeList<UInt32>, UInt32]
+
+        let list_array_data = vec![
+            Some(vec![Some(2), Some(1)]), // 0
+            None,                         // 10
+            Some(vec![Some(3)]),          // 1
+            Some(vec![Some(2), Some(0)]), // 2
+            Some(vec![None, Some(2)]),    // 3
+            Some(vec![Some(0)]),          // none
+            None,                         // 11
+            Some(vec![Some(2), None]),    // 4
+            Some(vec![None]),             // 5
+            Some(vec![Some(2), Some(1)]), // 6
+        ];
+        let primitive_array_data = vec![
+            Some(0),
+            Some(10),
+            Some(1),
+            Some(2),
+            Some(3),
+            None,
+            Some(11),
+            Some(4),
+            Some(5),
+            Some(6),
+        ];
+        let expected_list_array_data = vec![
+            None,
+            None,
+            Some(vec![None]),
+            Some(vec![None, Some(2)]),
+            Some(vec![Some(0)]),
+            Some(vec![Some(2), None]),
+            Some(vec![Some(2), Some(0)]),
+            Some(vec![Some(2), Some(1)]),
+            Some(vec![Some(2), Some(1)]),
+            Some(vec![Some(3)]),
+        ];
+        let expected_primitive_array_data = vec![
+            Some(10),
+            Some(11),
+            Some(5),
+            Some(3),
+            None,
+            Some(4),
+            Some(2),
+            Some(0),
+            Some(6),
+            Some(1),
+        ];
+        test_lex_sort_mixed_types_with_list::<Int32Type>(
+            list_array_data.clone(),
+            primitive_array_data.clone(),
+            expected_list_array_data,
+            expected_primitive_array_data,
+            None,
+            None,
+        );
+    }
+
+    fn test_lex_sort_mixed_types_with_fixed_size_list<T>(
+        primitive_array_data: Vec<Option<T::Native>>,
+        list_array_data: Vec<Option<Vec<Option<T::Native>>>>,
+        expected_primitive_array_data: Vec<Option<T::Native>>,
+        expected_list_array_data: Vec<Option<Vec<Option<T::Native>>>>,
+        primitive_array_options: Option<SortOptions>,
+        list_array_options: Option<SortOptions>,
+    ) where
+        T: ArrowPrimitiveType,
+        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
+    {
+        let input = vec![
+            SortColumn {
+                values: Arc::new(PrimitiveArray::<T>::from(primitive_array_data.clone()))
+                    as ArrayRef,
+                options: primitive_array_options,
+            },
+            SortColumn {
+                values: Arc::new(FixedSizeListArray::from_iter_primitive::<T, _, _>(
+                    list_array_data.clone(),
+                    1,
+                )) as ArrayRef,
+                options: list_array_options,
+            },
+        ];
+
+        let expected = vec![
+            Arc::new(PrimitiveArray::<T>::from(
+                expected_primitive_array_data.clone(),
+            )) as ArrayRef,
+            Arc::new(FixedSizeListArray::from_iter_primitive::<T, _, _>(
+                expected_list_array_data.clone(),
+                1,
+            )) as ArrayRef,
+        ];
+
+        test_lex_sort_arrays(input.clone(), expected.clone(), None);
+        test_lex_sort_arrays(input.clone(), slice_arrays(expected.clone(), 0, 5), Some(5));
+    }
+
+    fn test_lex_sort_mixed_types_with_list<T>(
+        list_array_data: Vec<Option<Vec<Option<T::Native>>>>,
+        primitive_array_data: Vec<Option<T::Native>>,
+        expected_list_array_data: Vec<Option<Vec<Option<T::Native>>>>,
+        expected_primitive_array_data: Vec<Option<T::Native>>,
+        list_array_options: Option<SortOptions>,
+        primitive_array_options: Option<SortOptions>,
+    ) where
+        T: ArrowPrimitiveType,
+        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
+    {
+        macro_rules! run_test {
+            ($ARRAY_TYPE:ident) => {
+                let input = vec![
+                    SortColumn {
+                        values: Arc::new(<$ARRAY_TYPE>::from_iter_primitive::<T, _, _>(
+                            list_array_data.clone(),
+                        )) as ArrayRef,
+                        options: list_array_options.clone(),
+                    },
+                    SortColumn {
+                        values: Arc::new(PrimitiveArray::<T>::from(primitive_array_data.clone()))
+                            as ArrayRef,
+                        options: primitive_array_options.clone(),
+                    },
+                ];
+
+                let expected = vec![
+                    Arc::new(<$ARRAY_TYPE>::from_iter_primitive::<T, _, _>(
+                        expected_list_array_data.clone(),
+                    )) as ArrayRef,
+                    Arc::new(PrimitiveArray::<T>::from(
+                        expected_primitive_array_data.clone(),
+                    )) as ArrayRef,
+                ];
+
+                test_lex_sort_arrays(input.clone(), expected.clone(), None);
+                test_lex_sort_arrays(input.clone(), slice_arrays(expected.clone(), 0, 5), Some(5));
+            };
+        }
+        run_test!(ListArray);
+        run_test!(LargeListArray);
     }
 
     #[test]
diff --git a/arrow/benches/lexsort.rs b/arrow/benches/lexsort.rs
index bd2db1e5022..cd952299df4 100644
--- a/arrow/benches/lexsort.rs
+++ b/arrow/benches/lexsort.rs
@@ -20,8 +20,10 @@ use arrow::row::{RowConverter, SortField};
 use arrow::util::bench_util::{
     create_dict_from_values, create_primitive_array, create_string_array_with_len,
 };
+use arrow::util::data_gen::create_random_array;
 use arrow_array::types::Int32Type;
 use arrow_array::{Array, ArrayRef, UInt32Array};
+use arrow_schema::{DataType, Field};
 use criterion::{criterion_group, criterion_main, Criterion};
 use std::sync::Arc;
 
@@ -33,6 +35,10 @@ enum Column {
     Optional16CharString,
     Optional50CharString,
     Optional100Value50CharStringDict,
+    RequiredI32List,
+    OptionalI32List,
+    Required4CharStringList,
+    Optional4CharStringList,
 }
 
 impl std::fmt::Debug for Column {
@@ -44,6 +50,10 @@ impl std::fmt::Debug for Column {
             Column::Optional16CharString => "str_opt(16)",
             Column::Optional50CharString => "str_opt(50)",
             Column::Optional100Value50CharStringDict => "dict(100,str_opt(50))",
+            Column::RequiredI32List => "i32_list",
+            Column::OptionalI32List => "i32_list_opt",
+            Column::Required4CharStringList => "str_list(4)",
+            Column::Optional4CharStringList => "str_list_opt(4)",
         };
         f.write_str(s)
     }
@@ -70,6 +80,38 @@ impl Column {
                     &create_string_array_with_len::<i32>(100, 0., 50),
                 ))
             }
+            Column::RequiredI32List => {
+                let field = Field::new(
+                    "_1",
+                    DataType::List(Arc::new(Field::new("item", DataType::Int32, false))),
+                    true,
+                );
+                create_random_array(&field, size, 0., 1.).unwrap()
+            }
+            Column::OptionalI32List => {
+                let field = Field::new(
+                    "_1",
+                    DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
+                    true,
+                );
+                create_random_array(&field, size, 0.2, 1.).unwrap()
+            }
+            Column::Required4CharStringList => {
+                let field = Field::new(
+                    "_1",
+                    DataType::List(Arc::new(Field::new("item", DataType::Utf8, false))),
+                    true,
+                );
+                create_random_array(&field, size, 0., 1.).unwrap()
+            }
+            Column::Optional4CharStringList => {
+                let field = Field::new(
+                    "_1",
+                    DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
+                    true,
+                );
+                create_random_array(&field, size, 0.2, 1.).unwrap()
+            }
         }
     }
 }
@@ -150,6 +192,23 @@ fn add_benchmark(c: &mut Criterion) {
             Column::Optional100Value50CharStringDict,
             Column::Optional50CharString,
         ],
+        &[Column::OptionalI32, Column::RequiredI32List],
+        &[Column::OptionalI32, Column::OptionalI32List],
+        &[Column::OptionalI32List, Column::OptionalI32],
+        &[Column::RequiredI32, Column::Required4CharStringList],
+        &[Column::Required4CharStringList, Column::RequiredI32],
+        &[Column::RequiredI32, Column::Optional4CharStringList],
+        &[Column::Optional4CharStringList, Column::RequiredI32],
+        &[
+            Column::RequiredI32,
+            Column::RequiredI32List,
+            Column::Required16CharString,
+        ],
+        &[
+            Column::OptionalI32,
+            Column::OptionalI32List,
+            Column::Optional50CharString,
+        ],
     ];
 
     for case in cases {

Reply via email to