This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 9828bf0bcd3 Push SortOptions into DynComparator Allowing Nested Comparisons (#5426) (#5792)
9828bf0bcd3 is described below

commit 9828bf0bcd3e54ff5c51154ee99d183d9ee171fa
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Tue May 28 19:03:07 2024 +0100

    Push SortOptions into DynComparator Allowing Nested Comparisons (#5426) (#5792)
    
    * Push SortOptions into DynComparator (#5426)
    
    * Clippy
    
    * Review feedback
    
    * Tweak print_row
---
 arrow-ord/src/ord.rs   | 551 ++++++++++++++++++++++++++++++++++++++++++-------
 arrow-ord/src/sort.rs  | 114 ++--------
 arrow-row/src/lib.rs   |  53 ++++-
 arrow/src/array/mod.rs |   3 +-
 4 files changed, 542 insertions(+), 179 deletions(-)

diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs
index 8f21cd7c498..3825e5ec66f 100644
--- a/arrow-ord/src/ord.rs
+++ b/arrow-ord/src/ord.rs
@@ -20,36 +20,117 @@
 use arrow_array::cast::AsArray;
 use arrow_array::types::*;
 use arrow_array::*;
-use arrow_buffer::ArrowNativeType;
-use arrow_schema::ArrowError;
+use arrow_buffer::{ArrowNativeType, NullBuffer};
+use arrow_schema::{ArrowError, SortOptions};
 use std::cmp::Ordering;
 
 /// Compare the values at two arbitrary indices in two arrays.
 pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;
 
-fn compare_primitive<T: ArrowPrimitiveType>(left: &dyn Array, right: &dyn Array) -> DynComparator
+/// If parent sort order is descending we need to invert the value of nulls_first so that
+/// when the parent is sorted based on the produced ranks, nulls are still ordered correctly
+fn child_opts(opts: SortOptions) -> SortOptions {
+    SortOptions {
+        descending: false,
+        nulls_first: opts.nulls_first != opts.descending,
+    }
+}
+
+fn compare<A, F>(l: &A, r: &A, opts: SortOptions, cmp: F) -> DynComparator
 where
-    T::Native: ArrowNativeTypeOp,
+    A: Array + Clone,
+    F: Fn(usize, usize) -> Ordering + Send + Sync + 'static,
 {
-    let left = left.as_primitive::<T>().clone();
-    let right = right.as_primitive::<T>().clone();
-    Box::new(move |i, j| left.value(i).compare(right.value(j)))
+    let l = l.logical_nulls().filter(|x| x.null_count() > 0);
+    let r = r.logical_nulls().filter(|x| x.null_count() > 0);
+    match (opts.nulls_first, opts.descending) {
+        (true, true) => compare_impl::<true, true, _>(l, r, cmp),
+        (true, false) => compare_impl::<true, false, _>(l, r, cmp),
+        (false, true) => compare_impl::<false, true, _>(l, r, cmp),
+        (false, false) => compare_impl::<false, false, _>(l, r, cmp),
+    }
 }
 
-fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator {
-    let left: BooleanArray = left.as_boolean().clone();
-    let right: BooleanArray = right.as_boolean().clone();
+fn compare_impl<const NULLS_FIRST: bool, const DESCENDING: bool, F>(
+    l: Option<NullBuffer>,
+    r: Option<NullBuffer>,
+    cmp: F,
+) -> DynComparator
+where
+    F: Fn(usize, usize) -> Ordering + Send + Sync + 'static,
+{
+    let cmp = move |i, j| match DESCENDING {
+        true => cmp(i, j).reverse(),
+        false => cmp(i, j),
+    };
+
+    let (left_null, right_null) = match NULLS_FIRST {
+        true => (Ordering::Less, Ordering::Greater),
+        false => (Ordering::Greater, Ordering::Less),
+    };
+
+    match (l, r) {
+        (None, None) => Box::new(cmp),
+        (Some(l), None) => Box::new(move |i, j| match l.is_null(i) {
+            true => left_null,
+            false => cmp(i, j),
+        }),
+        (None, Some(r)) => Box::new(move |i, j| match r.is_null(j) {
+            true => right_null,
+            false => cmp(i, j),
+        }),
+        (Some(l), Some(r)) => Box::new(move |i, j| match (l.is_null(i), r.is_null(j)) {
+            (true, true) => Ordering::Equal,
+            (true, false) => left_null,
+            (false, true) => right_null,
+            (false, false) => cmp(i, j),
+        }),
+    }
+}
+
+fn compare_primitive<T: ArrowPrimitiveType>(
+    left: &dyn Array,
+    right: &dyn Array,
+    opts: SortOptions,
+) -> DynComparator
+where
+    T::Native: ArrowNativeTypeOp,
+{
+    let left = left.as_primitive::<T>();
+    let right = right.as_primitive::<T>();
+    let l_values = left.values().clone();
+    let r_values = right.values().clone();
 
-    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
+    compare(&left, &right, opts, move |i, j| {
+        l_values[i].compare(r_values[j])
+    })
 }
 
-fn compare_bytes<T: ByteArrayType>(left: &dyn Array, right: &dyn Array) -> DynComparator {
-    let left = left.as_bytes::<T>().clone();
-    let right = right.as_bytes::<T>().clone();
+fn compare_boolean(left: &dyn Array, right: &dyn Array, opts: SortOptions) -> DynComparator {
+    let left = left.as_boolean();
+    let right = right.as_boolean();
+
+    let l_values = left.values().clone();
+    let r_values = right.values().clone();
 
-    Box::new(move |i, j| {
-        let l: &[u8] = left.value(i).as_ref();
-        let r: &[u8] = right.value(j).as_ref();
+    compare(left, right, opts, move |i, j| {
+        l_values.value(i).cmp(&r_values.value(j))
+    })
+}
+
+fn compare_bytes<T: ByteArrayType>(
+    left: &dyn Array,
+    right: &dyn Array,
+    opts: SortOptions,
+) -> DynComparator {
+    let left = left.as_bytes::<T>();
+    let right = right.as_bytes::<T>();
+
+    let l = left.clone();
+    let r = right.clone();
+    compare(left, right, opts, move |i, j| {
+        let l: &[u8] = l.value(i).as_ref();
+        let r: &[u8] = r.value(j).as_ref();
         l.cmp(r)
     })
 }
@@ -57,67 +138,234 @@ fn compare_bytes<T: ByteArrayType>(left: &dyn Array, right: &dyn Array) -> DynCo
 fn compare_dict<K: ArrowDictionaryKeyType>(
     left: &dyn Array,
     right: &dyn Array,
+    opts: SortOptions,
 ) -> Result<DynComparator, ArrowError> {
     let left = left.as_dictionary::<K>();
     let right = right.as_dictionary::<K>();
 
-    let cmp = build_compare(left.values().as_ref(), right.values().as_ref())?;
-    let left_keys = left.keys().clone();
-    let right_keys = right.keys().clone();
+    let c_opts = child_opts(opts);
+    let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?;
+    let left_keys = left.keys().values().clone();
+    let right_keys = right.keys().values().clone();
 
-    // TODO: Handle value nulls (#2687)
-    Ok(Box::new(move |i, j| {
-        let l = left_keys.value(i).as_usize();
-        let r = right_keys.value(j).as_usize();
+    let f = compare(left, right, opts, move |i, j| {
+        let l = left_keys[i].as_usize();
+        let r = right_keys[j].as_usize();
         cmp(l, r)
-    }))
+    });
+    Ok(f)
 }
 
-/// returns a comparison function that compares two values at two different positions
+fn compare_list<O: OffsetSizeTrait>(
+    left: &dyn Array,
+    right: &dyn Array,
+    opts: SortOptions,
+) -> Result<DynComparator, ArrowError> {
+    let left = left.as_list::<O>();
+    let right = right.as_list::<O>();
+
+    let c_opts = child_opts(opts);
+    let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?;
+
+    let l_o = left.offsets().clone();
+    let r_o = right.offsets().clone();
+    let f = compare(left, right, opts, move |i, j| {
+        let l_end = l_o[i + 1].as_usize();
+        let l_start = l_o[i].as_usize();
+
+        let r_end = r_o[j + 1].as_usize();
+        let r_start = r_o[j].as_usize();
+
+        for (i, j) in (l_start..l_end).zip(r_start..r_end) {
+            match cmp(i, j) {
+                Ordering::Equal => continue,
+                r => return r,
+            }
+        }
+        (l_end - l_start).cmp(&(r_end - r_start))
+    });
+    Ok(f)
+}
+
+fn compare_fixed_list(
+    left: &dyn Array,
+    right: &dyn Array,
+    opts: SortOptions,
+) -> Result<DynComparator, ArrowError> {
+    let left = left.as_fixed_size_list();
+    let right = right.as_fixed_size_list();
+
+    let c_opts = child_opts(opts);
+    let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?;
+
+    let l_size = left.value_length().to_usize().unwrap();
+    let r_size = right.value_length().to_usize().unwrap();
+    let size_cmp = l_size.cmp(&r_size);
+
+    let f = compare(left, right, opts, move |i, j| {
+        let l_start = i * l_size;
+        let l_end = l_start + l_size;
+        let r_start = j * r_size;
+        let r_end = r_start + r_size;
+        for (i, j) in (l_start..l_end).zip(r_start..r_end) {
+            match cmp(i, j) {
+                Ordering::Equal => continue,
+                r => return r,
+            }
+        }
+        size_cmp
+    });
+    Ok(f)
+}
+
+fn compare_struct(
+    left: &dyn Array,
+    right: &dyn Array,
+    opts: SortOptions,
+) -> Result<DynComparator, ArrowError> {
+    let left = left.as_struct();
+    let right = right.as_struct();
+
+    if left.columns().len() != right.columns().len() {
+        return Err(ArrowError::InvalidArgumentError(
+            "Cannot compare StructArray with different number of columns".to_string(),
+        ));
+    }
+
+    let c_opts = child_opts(opts);
+    let columns = left.columns().iter().zip(right.columns());
+    let comparators = columns
+        .map(|(l, r)| make_comparator(l, r, c_opts))
+        .collect::<Result<Vec<_>, _>>()?;
+
+    let f = compare(left, right, opts, move |i, j| {
+        for cmp in &comparators {
+            match cmp(i, j) {
+                Ordering::Equal => continue,
+                r => return r,
+            }
+        }
+        Ordering::Equal
+    });
+    Ok(f)
+}
+
+#[deprecated(note = "Use make_comparator")]
+#[doc(hidden)]
+pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result<DynComparator, ArrowError> {
+    make_comparator(left, right, SortOptions::default())
+}
+
+/// Returns a comparison function that compares two values at two different positions
 /// between the two arrays.
-/// The arrays' types must be equal.
-/// # Example
-/// ```
-/// use arrow_array::Int32Array;
-/// use arrow_ord::ord::build_compare;
 ///
+/// For comparing arrays element-wise, see also the vectorised kernels in [`crate::cmp`].
+///
+/// If `nulls_first` is true `NULL` values will be considered less than any non-null value,
+/// otherwise they will be considered greater.
+///
+/// # Basic Usage
+///
+/// ```
+/// # use std::cmp::Ordering;
+/// # use arrow_array::Int32Array;
+/// # use arrow_ord::ord::make_comparator;
+/// # use arrow_schema::SortOptions;
+/// #
 /// let array1 = Int32Array::from(vec![1, 2]);
 /// let array2 = Int32Array::from(vec![3, 4]);
 ///
-/// let cmp = build_compare(&array1, &array2).unwrap();
-///
+/// let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
 /// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
-/// assert_eq!(std::cmp::Ordering::Less, cmp(0, 1));
+/// assert_eq!(cmp(0, 1), Ordering::Less);
+///
+/// let array1 = Int32Array::from(vec![Some(1), None]);
+/// let array2 = Int32Array::from(vec![None, Some(2)]);
+/// let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
+///
+/// assert_eq!(cmp(0, 1), Ordering::Less); // Some(1) vs Some(2)
+/// assert_eq!(cmp(1, 1), Ordering::Less); // None vs Some(2)
+/// assert_eq!(cmp(1, 0), Ordering::Equal); // None vs None
+/// assert_eq!(cmp(0, 0), Ordering::Greater); // Some(1) vs None
 /// ```
-// This is a factory of comparisons.
-// The lifetime 'a enforces that we cannot use the closure beyond any of the 
array's lifetime.
-pub fn build_compare(left: &dyn Array, right: &dyn Array) -> 
Result<DynComparator, ArrowError> {
+///
+/// # Postgres-compatible Nested Comparison
+///
+/// Whilst SQL prescribes ternary logic for nulls, that is comparing a value against a NULL yields
+/// a NULL, many systems, including postgres, instead apply a total ordering to comparison of
+/// nested nulls. That is nulls within nested types are either greater than any value (postgres),
+/// or less than any value (Spark).
+///
+/// In particular
+///
+/// ```ignore
+/// { a: 1, b: null } == { a: 1, b: null } => true
+/// { a: 1, b: null } == { a: 1, b: 1 } => false
+/// { a: 1, b: null } == null => null
+/// null == null => null
+/// ```
+///
+/// This could be implemented as below
+///
+/// ```
+/// # use arrow_array::{Array, BooleanArray};
+/// # use arrow_buffer::NullBuffer;
+/// # use arrow_ord::cmp;
+/// # use arrow_ord::ord::make_comparator;
+/// # use arrow_schema::{ArrowError, SortOptions};
+/// fn eq(a: &dyn Array, b: &dyn Array) -> Result<BooleanArray, ArrowError> {
+///     if !a.data_type().is_nested() {
+///         return cmp::eq(&a, &b); // Use faster vectorised kernel
+///     }
+///
+///     let cmp = make_comparator(a, b, SortOptions::default())?;
+///     let len = a.len().min(b.len());
+///     let values = (0..len).map(|i| cmp(i, i).is_eq()).collect();
+///     let nulls = NullBuffer::union(a.nulls(), b.nulls());
+///     Ok(BooleanArray::new(values, nulls))
+/// }
+/// ````
+pub fn make_comparator(
+    left: &dyn Array,
+    right: &dyn Array,
+    opts: SortOptions,
+) -> Result<DynComparator, ArrowError> {
     use arrow_schema::DataType::*;
+
     macro_rules! primitive_helper {
-        ($t:ty, $left:expr, $right:expr) => {
-            Ok(compare_primitive::<$t>($left, $right))
+        ($t:ty, $left:expr, $right:expr, $nulls_first:expr) => {
+            Ok(compare_primitive::<$t>($left, $right, $nulls_first))
         };
     }
     downcast_primitive! {
-        left.data_type(), right.data_type() => (primitive_helper, left, right),
-        (Boolean, Boolean) => Ok(compare_boolean(left, right)),
-        (Utf8, Utf8) => Ok(compare_bytes::<Utf8Type>(left, right)),
-        (LargeUtf8, LargeUtf8) => Ok(compare_bytes::<LargeUtf8Type>(left, right)),
-        (Binary, Binary) => Ok(compare_bytes::<BinaryType>(left, right)),
-        (LargeBinary, LargeBinary) => Ok(compare_bytes::<LargeBinaryType>(left, right)),
+        left.data_type(), right.data_type() => (primitive_helper, left, right, opts),
+        (Boolean, Boolean) => Ok(compare_boolean(left, right, opts)),
+        (Utf8, Utf8) => Ok(compare_bytes::<Utf8Type>(left, right, opts)),
+        (LargeUtf8, LargeUtf8) => Ok(compare_bytes::<LargeUtf8Type>(left, right, opts)),
+        (Binary, Binary) => Ok(compare_bytes::<BinaryType>(left, right, opts)),
+        (LargeBinary, LargeBinary) => Ok(compare_bytes::<LargeBinaryType>(left, right, opts)),
         (FixedSizeBinary(_), FixedSizeBinary(_)) => {
-            let left = left.as_fixed_size_binary().clone();
-            let right = right.as_fixed_size_binary().clone();
-            Ok(Box::new(move |i, j| left.value(i).cmp(right.value(j))))
+            let left = left.as_fixed_size_binary();
+            let right = right.as_fixed_size_binary();
+
+            let l = left.clone();
+            let r = right.clone();
+            Ok(compare(left, right, opts, move |i, j| {
+                l.value(i).cmp(r.value(j))
+            }))
         },
+        (List(_), List(_)) => compare_list::<i32>(left, right, opts),
+        (LargeList(_), LargeList(_)) => compare_list::<i64>(left, right, opts),
+        (FixedSizeList(_, _), FixedSizeList(_, _)) => compare_fixed_list(left, right, opts),
+        (Struct(_), Struct(_)) => compare_struct(left, right, opts),
         (Dictionary(l_key, _), Dictionary(r_key, _)) => {
              macro_rules! dict_helper {
-                ($t:ty, $left:expr, $right:expr) => {
-                     compare_dict::<$t>($left, $right)
+                ($t:ty, $left:expr, $right:expr, $opts: expr) => {
+                     compare_dict::<$t>($left, $right, $opts)
                  };
              }
             downcast_integer! {
-                 l_key.as_ref(), r_key.as_ref() => (dict_helper, left, right),
+                 l_key.as_ref(), r_key.as_ref() => (dict_helper, left, right, opts),
                  _ => unreachable!()
              }
         },
@@ -131,7 +379,9 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result<DynComparato
 #[cfg(test)]
 pub mod tests {
     use super::*;
+    use arrow_array::builder::{Int32Builder, ListBuilder};
     use arrow_buffer::{i256, IntervalDayTime, OffsetBuffer};
+    use arrow_schema::{DataType, Field, Fields};
     use half::f16;
     use std::sync::Arc;
 
@@ -140,7 +390,7 @@ pub mod tests {
         let items = vec![vec![1u8], vec![2u8]];
         let array = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap();
 
-        let cmp = build_compare(&array, &array).unwrap();
+        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 1));
     }
@@ -152,7 +402,7 @@ pub mod tests {
         let items = vec![vec![2u8]];
         let array2 = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap();
 
-        let cmp = build_compare(&array1, &array2).unwrap();
+        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 0));
     }
@@ -161,7 +411,7 @@ pub mod tests {
     fn test_i32() {
         let array = Int32Array::from(vec![1, 2]);
 
-        let cmp = build_compare(&array, &array).unwrap();
+        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, (cmp)(0, 1));
     }
@@ -171,7 +421,7 @@ pub mod tests {
         let array1 = Int32Array::from(vec![1]);
         let array2 = Int32Array::from(vec![2]);
 
-        let cmp = build_compare(&array1, &array2).unwrap();
+        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 0));
     }
@@ -180,7 +430,7 @@ pub mod tests {
     fn test_f16() {
         let array = Float16Array::from(vec![f16::from_f32(1.0), f16::from_f32(2.0)]);
 
-        let cmp = build_compare(&array, &array).unwrap();
+        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 1));
     }
@@ -189,7 +439,7 @@ pub mod tests {
     fn test_f64() {
         let array = Float64Array::from(vec![1.0, 2.0]);
 
-        let cmp = build_compare(&array, &array).unwrap();
+        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 1));
     }
@@ -198,7 +448,7 @@ pub mod tests {
     fn test_f64_nan() {
         let array = Float64Array::from(vec![1.0, f64::NAN]);
 
-        let cmp = build_compare(&array, &array).unwrap();
+        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 1));
         assert_eq!(Ordering::Equal, cmp(1, 1));
@@ -208,7 +458,7 @@ pub mod tests {
     fn test_f64_zeros() {
         let array = Float64Array::from(vec![-0.0, 0.0]);
 
-        let cmp = build_compare(&array, &array).unwrap();
+        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 1));
         assert_eq!(Ordering::Greater, cmp(1, 0));
@@ -225,7 +475,7 @@ pub mod tests {
             IntervalDayTimeType::make_value(0, 90_000_000),
         ]);
 
-        let cmp = build_compare(&array, &array).unwrap();
+        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 1));
         assert_eq!(Ordering::Greater, cmp(1, 0));
@@ -248,7 +498,7 @@ pub mod tests {
             IntervalYearMonthType::make_value(1, 1),
         ]);
 
-        let cmp = build_compare(&array, &array).unwrap();
+        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 1));
         assert_eq!(Ordering::Greater, cmp(1, 0));
@@ -269,7 +519,7 @@ pub mod tests {
             IntervalMonthDayNanoType::make_value(0, 100, 2),
         ]);
 
-        let cmp = build_compare(&array, &array).unwrap();
+        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 1));
         assert_eq!(Ordering::Greater, cmp(1, 0));
@@ -289,7 +539,7 @@ pub mod tests {
             .with_precision_and_scale(23, 6)
             .unwrap();
 
-        let cmp = build_compare(&array, &array).unwrap();
+        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
         assert_eq!(Ordering::Less, cmp(1, 0));
         assert_eq!(Ordering::Greater, cmp(0, 2));
     }
@@ -306,7 +556,7 @@ pub mod tests {
         .with_precision_and_scale(53, 6)
         .unwrap();
 
-        let cmp = build_compare(&array, &array).unwrap();
+        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
         assert_eq!(Ordering::Less, cmp(1, 0));
         assert_eq!(Ordering::Greater, cmp(0, 2));
     }
@@ -316,7 +566,7 @@ pub mod tests {
         let data = vec!["a", "b", "c", "a", "a", "c", "c"];
         let array = data.into_iter().collect::<DictionaryArray<Int16Type>>();
 
-        let cmp = build_compare(&array, &array).unwrap();
+        let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 1));
         assert_eq!(Ordering::Equal, cmp(3, 4));
@@ -330,7 +580,7 @@ pub mod tests {
         let d2 = vec!["e", "f", "g", "a"];
         let a2 = d2.into_iter().collect::<DictionaryArray<Int16Type>>();
 
-        let cmp = build_compare(&a1, &a2).unwrap();
+        let cmp = make_comparator(&a1, &a2, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 0));
         assert_eq!(Ordering::Equal, cmp(0, 3));
@@ -347,7 +597,7 @@ pub mod tests {
         let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
         let array2 = DictionaryArray::new(keys, Arc::new(values));
 
-        let cmp = build_compare(&array1, &array2).unwrap();
+        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 0));
         assert_eq!(Ordering::Less, cmp(0, 3));
@@ -366,7 +616,7 @@ pub mod tests {
         let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
         let array2 = DictionaryArray::new(keys, Arc::new(values));
 
-        let cmp = build_compare(&array1, &array2).unwrap();
+        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 0));
         assert_eq!(Ordering::Less, cmp(0, 3));
@@ -385,7 +635,7 @@ pub mod tests {
         let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
         let array2 = DictionaryArray::new(keys, Arc::new(values));
 
-        let cmp = build_compare(&array1, &array2).unwrap();
+        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 0));
         assert_eq!(Ordering::Less, cmp(0, 3));
@@ -408,7 +658,7 @@ pub mod tests {
         let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
         let array2 = DictionaryArray::new(keys, Arc::new(values));
 
-        let cmp = build_compare(&array1, &array2).unwrap();
+        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 0)); // v1 vs v3
         assert_eq!(Ordering::Equal, cmp(0, 3)); // v1 vs v1
@@ -427,7 +677,7 @@ pub mod tests {
         let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
         let array2 = DictionaryArray::new(keys, Arc::new(values));
 
-        let cmp = build_compare(&array1, &array2).unwrap();
+        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 0));
         assert_eq!(Ordering::Less, cmp(0, 3));
@@ -446,7 +696,7 @@ pub mod tests {
         let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
         let array2 = DictionaryArray::new(keys, Arc::new(values));
 
-        let cmp = build_compare(&array1, &array2).unwrap();
+        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 0));
         assert_eq!(Ordering::Less, cmp(0, 3));
@@ -475,7 +725,7 @@ pub mod tests {
         let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
         let array2 = DictionaryArray::new(keys, Arc::new(values));
 
-        let cmp = build_compare(&array1, &array2).unwrap();
+        let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 0));
         assert_eq!(Ordering::Less, cmp(0, 3));
@@ -487,7 +737,7 @@ pub mod tests {
     fn test_bytes_impl<T: ByteArrayType>() {
         let offsets = OffsetBuffer::from_lengths([3, 3, 1]);
         let a = GenericByteArray::<T>::new(offsets, b"abcdefa".into(), None);
-        let cmp = build_compare(&a, &a).unwrap();
+        let cmp = make_comparator(&a, &a, SortOptions::default()).unwrap();
 
         assert_eq!(Ordering::Less, cmp(0, 1));
         assert_eq!(Ordering::Greater, cmp(0, 2));
@@ -501,4 +751,157 @@ pub mod tests {
         test_bytes_impl::<BinaryType>();
         test_bytes_impl::<LargeBinaryType>();
     }
+
+    #[test]
+    fn test_lists() {
+        let mut a = ListBuilder::new(ListBuilder::new(Int32Builder::new()));
+        a.extend([
+            Some(vec![Some(vec![Some(1), Some(2), None]), Some(vec![None])]),
+            Some(vec![
+                Some(vec![Some(1), Some(2), Some(3)]),
+                Some(vec![Some(1)]),
+            ]),
+            Some(vec![]),
+        ]);
+        let a = a.finish();
+        let mut b = ListBuilder::new(ListBuilder::new(Int32Builder::new()));
+        b.extend([
+            Some(vec![Some(vec![Some(1), Some(2), None]), Some(vec![None])]),
+            Some(vec![
+                Some(vec![Some(1), Some(2), None]),
+                Some(vec![Some(1)]),
+            ]),
+            Some(vec![
+                Some(vec![Some(1), Some(2), Some(3), Some(4)]),
+                Some(vec![Some(1)]),
+            ]),
+            None,
+        ]);
+        let b = b.finish();
+
+        let opts = SortOptions {
+            descending: false,
+            nulls_first: true,
+        };
+        let cmp = make_comparator(&a, &b, opts).unwrap();
+        assert_eq!(cmp(0, 0), Ordering::Equal);
+        assert_eq!(cmp(0, 1), Ordering::Less);
+        assert_eq!(cmp(0, 2), Ordering::Less);
+        assert_eq!(cmp(1, 2), Ordering::Less);
+        assert_eq!(cmp(1, 3), Ordering::Greater);
+        assert_eq!(cmp(2, 0), Ordering::Less);
+
+        let opts = SortOptions {
+            descending: true,
+            nulls_first: true,
+        };
+        let cmp = make_comparator(&a, &b, opts).unwrap();
+        assert_eq!(cmp(0, 0), Ordering::Equal);
+        assert_eq!(cmp(0, 1), Ordering::Less);
+        assert_eq!(cmp(0, 2), Ordering::Less);
+        assert_eq!(cmp(1, 2), Ordering::Greater);
+        assert_eq!(cmp(1, 3), Ordering::Greater);
+        assert_eq!(cmp(2, 0), Ordering::Greater);
+
+        let opts = SortOptions {
+            descending: true,
+            nulls_first: false,
+        };
+        let cmp = make_comparator(&a, &b, opts).unwrap();
+        assert_eq!(cmp(0, 0), Ordering::Equal);
+        assert_eq!(cmp(0, 1), Ordering::Greater);
+        assert_eq!(cmp(0, 2), Ordering::Greater);
+        assert_eq!(cmp(1, 2), Ordering::Greater);
+        assert_eq!(cmp(1, 3), Ordering::Less);
+        assert_eq!(cmp(2, 0), Ordering::Greater);
+
+        let opts = SortOptions {
+            descending: false,
+            nulls_first: false,
+        };
+        let cmp = make_comparator(&a, &b, opts).unwrap();
+        assert_eq!(cmp(0, 0), Ordering::Equal);
+        assert_eq!(cmp(0, 1), Ordering::Greater);
+        assert_eq!(cmp(0, 2), Ordering::Greater);
+        assert_eq!(cmp(1, 2), Ordering::Less);
+        assert_eq!(cmp(1, 3), Ordering::Less);
+        assert_eq!(cmp(2, 0), Ordering::Less);
+    }
+
+    #[test]
+    fn test_struct() {
+        let fields = Fields::from(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new_list("b", Field::new("item", DataType::Int32, true), true),
+        ]);
+
+        let a = Int32Array::from(vec![Some(1), Some(2), None, None]);
+        let mut b = ListBuilder::new(Int32Builder::new());
+        b.extend([Some(vec![Some(1), Some(2)]), Some(vec![None]), None, None]);
+        let b = b.finish();
+
+        let nulls = Some(NullBuffer::from_iter([true, true, true, false]));
+        let values = vec![Arc::new(a) as _, Arc::new(b) as _];
+        let s1 = StructArray::new(fields.clone(), values, nulls);
+
+        let a = Int32Array::from(vec![None, Some(2), None]);
+        let mut b = ListBuilder::new(Int32Builder::new());
+        b.extend([None, None, Some(vec![])]);
+        let b = b.finish();
+
+        let values = vec![Arc::new(a) as _, Arc::new(b) as _];
+        let s2 = StructArray::new(fields.clone(), values, None);
+
+        let opts = SortOptions {
+            descending: false,
+            nulls_first: true,
+        };
+        let cmp = make_comparator(&s1, &s2, opts).unwrap();
+        assert_eq!(cmp(0, 1), Ordering::Less); // (1, [1, 2]) cmp (2, None)
+        assert_eq!(cmp(0, 0), Ordering::Greater); // (1, [1, 2]) cmp (None, None)
+        assert_eq!(cmp(1, 1), Ordering::Greater); // (2, [None]) cmp (2, None)
+        assert_eq!(cmp(2, 2), Ordering::Less); // (None, None) cmp (None, [])
+        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, [])
+        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
+        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, None)
+
+        let opts = SortOptions {
+            descending: true,
+            nulls_first: true,
+        };
+        let cmp = make_comparator(&s1, &s2, opts).unwrap();
+        assert_eq!(cmp(0, 1), Ordering::Greater); // (1, [1, 2]) cmp (2, None)
+        assert_eq!(cmp(0, 0), Ordering::Greater); // (1, [1, 2]) cmp (None, None)
+        assert_eq!(cmp(1, 1), Ordering::Greater); // (2, [None]) cmp (2, None)
+        assert_eq!(cmp(2, 2), Ordering::Less); // (None, None) cmp (None, [])
+        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, [])
+        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
+        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, None)
+
+        let opts = SortOptions {
+            descending: true,
+            nulls_first: false,
+        };
+        let cmp = make_comparator(&s1, &s2, opts).unwrap();
+        assert_eq!(cmp(0, 1), Ordering::Greater); // (1, [1, 2]) cmp (2, None)
+        assert_eq!(cmp(0, 0), Ordering::Less); // (1, [1, 2]) cmp (None, None)
+        assert_eq!(cmp(1, 1), Ordering::Less); // (2, [None]) cmp (2, None)
+        assert_eq!(cmp(2, 2), Ordering::Greater); // (None, None) cmp (None, [])
+        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, [])
+        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
+        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, None)
+
+        let opts = SortOptions {
+            descending: false,
+            nulls_first: false,
+        };
+        let cmp = make_comparator(&s1, &s2, opts).unwrap();
+        assert_eq!(cmp(0, 1), Ordering::Less); // (1, [1, 2]) cmp (2, None)
+        assert_eq!(cmp(0, 0), Ordering::Less); // (1, [1, 2]) cmp (None, None)
+        assert_eq!(cmp(1, 1), Ordering::Less); // (2, [None]) cmp (2, None)
+        assert_eq!(cmp(2, 2), Ordering::Greater); // (None, None) cmp (None, [])
+        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, [])
+        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
+        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, None)
+    }
 }
diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs
index fe3a1f86ac0..8ae87787d28 100644
--- a/arrow-ord/src/sort.rs
+++ b/arrow-ord/src/sort.rs
@@ -17,13 +17,13 @@
 
 //! Defines sort kernel for `ArrayRef`
 
-use crate::ord::{build_compare, DynComparator};
+use crate::ord::{make_comparator, DynComparator};
 use arrow_array::builder::BufferBuilder;
 use arrow_array::cast::*;
 use arrow_array::types::*;
 use arrow_array::*;
+use arrow_buffer::ArrowNativeType;
 use arrow_buffer::BooleanBufferBuilder;
-use arrow_buffer::{ArrowNativeType, NullBuffer};
 use arrow_data::ArrayDataBuilder;
 use arrow_schema::{ArrowError, DataType};
 use arrow_select::take::take;
@@ -704,60 +704,21 @@ where
     }
 }
 
-type LexicographicalCompareItem = (
-    Option<NullBuffer>, // nulls
-    DynComparator,      // comparator
-    SortOptions,        // sort_option
-);
-
 /// A lexicographical comparator that wraps given array data (columns) and can lexicographically compare data
 /// at given two indices. The lifetime is the same at the data wrapped.
 pub struct LexicographicalComparator {
-    compare_items: Vec<LexicographicalCompareItem>,
+    compare_items: Vec<DynComparator>,
 }
 
 impl LexicographicalComparator {
     /// lexicographically compare values at the wrapped columns with given indices.
     pub fn compare(&self, a_idx: usize, b_idx: usize) -> Ordering {
-        for (nulls, comparator, sort_option) in &self.compare_items {
-            let (lhs_valid, rhs_valid) = match nulls {
-                Some(n) => (n.is_valid(a_idx), n.is_valid(b_idx)),
-                None => (true, true),
-            };
-
-            match (lhs_valid, rhs_valid) {
-                (true, true) => {
-                    match (comparator)(a_idx, b_idx) {
-                        // equal, move on to next column
-                        Ordering::Equal => continue,
-                        order => {
-                            if sort_option.descending {
-                                return order.reverse();
-                            } else {
-                                return order;
-                            }
-                        }
-                    }
-                }
-                (false, true) => {
-                    return if sort_option.nulls_first {
-                        Ordering::Less
-                    } else {
-                        Ordering::Greater
-                    };
-                }
-                (true, false) => {
-                    return if sort_option.nulls_first {
-                        Ordering::Greater
-                    } else {
-                        Ordering::Less
-                    };
-                }
-                // equal, move on to next column
-                (false, false) => continue,
+        for comparator in &self.compare_items {
+            match comparator(a_idx, b_idx) {
+                Ordering::Equal => continue,
+                r => return r,
             }
         }
-
         Ordering::Equal
     }
 
@@ -766,61 +727,16 @@ impl LexicographicalComparator {
     pub fn try_new(columns: &[SortColumn]) -> Result<LexicographicalComparator, ArrowError> {
         let compare_items = columns
             .iter()
-            .map(Self::build_compare_item)
+            .map(|c| {
+                make_comparator(
+                    c.values.as_ref(),
+                    c.values.as_ref(),
+                    c.options.unwrap_or_default(),
+                )
+            })
             .collect::<Result<Vec<_>, ArrowError>>()?;
         Ok(LexicographicalComparator { compare_items })
     }
-
-    fn build_compare_item(column: &SortColumn) -> Result<LexicographicalCompareItem, ArrowError> {
-        let values = column.values.as_ref();
-        let options = column.options.unwrap_or_default();
-        let comparator = match values.data_type() {
-            DataType::List(_) => Self::build_list_compare(values.as_list::<i32>(), options)?,
-            DataType::LargeList(_) => Self::build_list_compare(values.as_list::<i64>(), options)?,
-            DataType::FixedSizeList(_, _) => {
-                Self::build_fixed_size_list_compare(values.as_fixed_size_list(), options)?
-            }
-            _ => build_compare(values, values)?,
-        };
-        Ok((values.logical_nulls(), comparator, options))
-    }
-
-    fn build_list_compare<O: OffsetSizeTrait>(
-        array: &GenericListArray<O>,
-        options: SortOptions,
-    ) -> Result<DynComparator, ArrowError> {
-        let rank = child_rank(array.values().as_ref(), options)?;
-        let offsets = array.offsets().clone();
-        let cmp = Box::new(move |i: usize, j: usize| {
-            macro_rules! nth_value {
-                ($INDEX:expr) => {{
-                    let end = offsets[$INDEX + 1].as_usize();
-                    let start = offsets[$INDEX].as_usize();
-                    &rank[start..end]
-                }};
-            }
-            Ord::cmp(nth_value!(i), nth_value!(j))
-        });
-        Ok(cmp)
-    }
-
-    fn build_fixed_size_list_compare(
-        array: &FixedSizeListArray,
-        options: SortOptions,
-    ) -> Result<DynComparator, ArrowError> {
-        let rank = child_rank(array.values().as_ref(), options)?;
-        let size = array.value_length() as usize;
-        let cmp = Box::new(move |i: usize, j: usize| {
-            macro_rules! nth_value {
-                ($INDEX:expr) => {{
-                    let start = $INDEX * size;
-                    &rank[start..start + size]
-                }};
-            }
-            Ord::cmp(nth_value!(i), nth_value!(j))
-        });
-        Ok(cmp)
-    }
 }
 
 #[cfg(test)]
@@ -829,7 +745,7 @@ mod tests {
     use arrow_array::builder::{
         FixedSizeListBuilder, Int64Builder, ListBuilder, PrimitiveRunBuilder,
     };
-    use arrow_buffer::i256;
+    use arrow_buffer::{i256, NullBuffer};
     use half::f16;
     use rand::rngs::StdRng;
     use rand::{Rng, RngCore, SeedableRng};
diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs
index 4dc7349ca2d..8e1285493b0 100644
--- a/arrow-row/src/lib.rs
+++ b/arrow-row/src/lib.rs
@@ -1302,9 +1302,9 @@ mod tests {
     use arrow_array::builder::*;
     use arrow_array::types::*;
     use arrow_array::*;
-    use arrow_buffer::i256;
-    use arrow_buffer::Buffer;
-    use arrow_cast::display::array_value_to_string;
+    use arrow_buffer::{i256, NullBuffer};
+    use arrow_buffer::{Buffer, OffsetBuffer};
+    use arrow_cast::display::{ArrayFormatter, FormatOptions};
     use arrow_ord::sort::{LexicographicalComparator, SortColumn};
 
     use super::*;
@@ -2099,9 +2099,35 @@ mod tests {
         builder.finish()
     }
 
+    fn generate_struct(len: usize, valid_percent: f64) -> StructArray {
+        let mut rng = thread_rng();
+        let nulls = NullBuffer::from_iter((0..len).map(|_| rng.gen_bool(valid_percent)));
+        let a = generate_primitive_array::<Int32Type>(len, valid_percent);
+        let b = generate_strings::<i32>(len, valid_percent);
+        let fields = Fields::from(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Utf8, true),
+        ]);
+        let values = vec![Arc::new(a) as _, Arc::new(b) as _];
+        StructArray::new(fields, values, Some(nulls))
+    }
+
+    fn generate_list<F>(len: usize, valid_percent: f64, values: F) -> ListArray
+    where
+        F: FnOnce(usize) -> ArrayRef,
+    {
+        let mut rng = thread_rng();
+        let offsets = OffsetBuffer::<i32>::from_lengths((0..len).map(|_| rng.gen_range(0..10)));
+        let values_len = offsets.last().unwrap().to_usize().unwrap();
+        let values = values(values_len);
+        let nulls = NullBuffer::from_iter((0..len).map(|_| rng.gen_bool(valid_percent)));
+        let field = Arc::new(Field::new("item", values.data_type().clone(), true));
+        ListArray::new(field, offsets, values, Some(nulls))
+    }
+
     fn generate_column(len: usize) -> ArrayRef {
         let mut rng = thread_rng();
-        match rng.gen_range(0..10) {
+        match rng.gen_range(0..14) {
             0 => Arc::new(generate_primitive_array::<Int32Type>(len, 0.8)),
             1 => Arc::new(generate_primitive_array::<UInt32Type>(len, 0.8)),
             2 => Arc::new(generate_primitive_array::<Int64Type>(len, 0.8)),
@@ -2125,6 +2151,16 @@ mod tests {
                 0.8,
             )),
             9 => Arc::new(generate_fixed_size_binary(len, 0.8)),
+            10 => Arc::new(generate_struct(len, 0.8)),
+            11 => Arc::new(generate_list(len, 0.8, |values_len| {
+                Arc::new(generate_primitive_array::<Int64Type>(values_len, 0.8))
+            })),
+            12 => Arc::new(generate_list(len, 0.8, |values_len| {
+                Arc::new(generate_strings::<i32>(values_len, 0.8))
+            })),
+            13 => Arc::new(generate_list(len, 0.8, |values_len| {
+                Arc::new(generate_struct(values_len, 0.8))
+            })),
             _ => unreachable!(),
         }
     }
@@ -2132,7 +2168,14 @@ mod tests {
     fn print_row(cols: &[SortColumn], row: usize) -> String {
         let t: Vec<_> = cols
             .iter()
-            .map(|x| array_value_to_string(&x.values, row).unwrap())
+            .map(|x| match x.values.is_valid(row) {
+                true => {
+                    let opts = FormatOptions::default().with_null("NULL");
+                    let formatter = ArrayFormatter::try_new(x.values.as_ref(), &opts).unwrap();
+                    formatter.value(row).to_string()
+                }
+                false => "NULL".to_string(),
+            })
             .collect();
         t.join(",")
     }
diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs
index b563c320bb6..242c9148cac 100644
--- a/arrow/src/array/mod.rs
+++ b/arrow/src/array/mod.rs
@@ -36,4 +36,5 @@ pub use arrow_array::ffi::export_array_into_raw;
 
 // --------------------- Array's values comparison ---------------------
 
-pub use arrow_ord::ord::{build_compare, DynComparator};
+#[allow(deprecated)]
+pub use arrow_ord::ord::{build_compare, make_comparator, DynComparator};


Reply via email to