tustvold commented on code in PR #5792:
URL: https://github.com/apache/arrow-rs/pull/5792#discussion_r1608640553


##########
arrow-ord/src/ord.rs:
##########
@@ -20,104 +20,338 @@
 use arrow_array::cast::AsArray;
 use arrow_array::types::*;
 use arrow_array::*;
-use arrow_buffer::ArrowNativeType;
-use arrow_schema::ArrowError;
+use arrow_buffer::{ArrowNativeType, NullBuffer};
+use arrow_schema::{ArrowError, SortOptions};
 use std::cmp::Ordering;
 
 /// Compare the values at two arbitrary indices in two arrays.
 pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;
 
-fn compare_primitive<T: ArrowPrimitiveType>(left: &dyn Array, right: &dyn 
Array) -> DynComparator
+/// If parent sort order is descending we need to invert the value of 
nulls_first so that
+/// when the parent is sorted based on the produced ranks, nulls are still 
ordered correctly
+fn child_opts(opts: SortOptions) -> SortOptions {
+    SortOptions {
+        descending: false,
+        nulls_first: opts.nulls_first != opts.descending,
+    }
+}
+
+fn compare<A, F>(l: &A, r: &A, opts: SortOptions, cmp: F) -> DynComparator
 where
-    T::Native: ArrowNativeTypeOp,
+    A: Array + Clone,
+    F: Fn(usize, usize) -> Ordering + Send + Sync + 'static,
 {
-    let left = left.as_primitive::<T>().clone();
-    let right = right.as_primitive::<T>().clone();
-    Box::new(move |i, j| left.value(i).compare(right.value(j)))
+    let l = l.logical_nulls().filter(|x| x.null_count() > 0);
+    let r = r.logical_nulls().filter(|x| x.null_count() > 0);
+    match (opts.nulls_first, opts.descending) {
+        (true, true) => compare_impl::<true, true, _>(l, r, cmp),
+        (true, false) => compare_impl::<true, false, _>(l, r, cmp),
+        (false, true) => compare_impl::<false, true, _>(l, r, cmp),
+        (false, false) => compare_impl::<false, false, _>(l, r, cmp),
+    }
 }
 
-fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator {
-    let left: BooleanArray = left.as_boolean().clone();
-    let right: BooleanArray = right.as_boolean().clone();
+fn compare_impl<const NULLS_FIRST: bool, const DESCENDING: bool, F>(
+    l: Option<NullBuffer>,
+    r: Option<NullBuffer>,
+    cmp: F,
+) -> DynComparator
+where
+    F: Fn(usize, usize) -> Ordering + Send + Sync + 'static,
+{
+    let cmp = move |i, j| match DESCENDING {
+        true => cmp(i, j).reverse(),
+        false => cmp(i, j),
+    };
+
+    let (left_null, right_null) = match NULLS_FIRST {
+        true => (Ordering::Less, Ordering::Greater),
+        false => (Ordering::Greater, Ordering::Less),
+    };
+
+    match (l, r) {
+        (None, None) => Box::new(cmp),
+        (Some(l), None) => Box::new(move |i, j| match l.is_null(i) {
+            true => left_null,
+            false => cmp(i, j),
+        }),
+        (None, Some(r)) => Box::new(move |i, j| match r.is_null(j) {
+            true => right_null,
+            false => cmp(i, j),
+        }),
+        (Some(l), Some(r)) => Box::new(move |i, j| match (l.is_null(i), 
r.is_null(j)) {
+            (true, true) => Ordering::Equal,
+            (true, false) => left_null,
+            (false, true) => right_null,
+            (false, false) => cmp(i, j),
+        }),
+    }
+}
+
+fn compare_primitive<T: ArrowPrimitiveType>(
+    left: &dyn Array,
+    right: &dyn Array,
+    opts: SortOptions,
+) -> DynComparator
+where
+    T::Native: ArrowNativeTypeOp,
+{
+    let left = left.as_primitive::<T>();
+    let right = right.as_primitive::<T>();
+    let l_values = left.values().clone();
+    let r_values = right.values().clone();
 
-    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
+    compare(&left, &right, opts, move |i, j| {
+        l_values[i].compare(r_values[j])
+    })
 }
 
-fn compare_bytes<T: ByteArrayType>(left: &dyn Array, right: &dyn Array) -> 
DynComparator {
-    let left = left.as_bytes::<T>().clone();
-    let right = right.as_bytes::<T>().clone();
+fn compare_boolean(left: &dyn Array, right: &dyn Array, opts: SortOptions) -> 
DynComparator {
+    let left = left.as_boolean();
+    let right = right.as_boolean();
+
+    let l_values = left.values().clone();
+    let r_values = right.values().clone();
+
+    compare(left, right, opts, move |i, j| {
+        l_values.value(i).cmp(&r_values.value(j))
+    })
+}
 
-    Box::new(move |i, j| {
-        let l: &[u8] = left.value(i).as_ref();
-        let r: &[u8] = right.value(j).as_ref();
+fn compare_bytes<T: ByteArrayType>(
+    left: &dyn Array,
+    right: &dyn Array,
+    opts: SortOptions,
+) -> DynComparator {
+    let left = left.as_bytes::<T>();
+    let right = right.as_bytes::<T>();
+
+    let l = left.clone();
+    let r = right.clone();
+    compare(left, right, opts, move |i, j| {
+        let l: &[u8] = l.value(i).as_ref();
+        let r: &[u8] = r.value(j).as_ref();
         l.cmp(r)
     })
 }
 
 fn compare_dict<K: ArrowDictionaryKeyType>(
     left: &dyn Array,
     right: &dyn Array,
+    opts: SortOptions,
 ) -> Result<DynComparator, ArrowError> {
     let left = left.as_dictionary::<K>();
     let right = right.as_dictionary::<K>();
 
-    let cmp = build_compare(left.values().as_ref(), right.values().as_ref())?;
-    let left_keys = left.keys().clone();
-    let right_keys = right.keys().clone();
+    let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), 
opts)?;
+    let left_keys = left.keys().values().clone();
+    let right_keys = right.keys().values().clone();
 
-    // TODO: Handle value nulls (#2687)
-    Ok(Box::new(move |i, j| {
-        let l = left_keys.value(i).as_usize();
-        let r = right_keys.value(j).as_usize();
+    let f = compare(left, right, opts, move |i, j| {
+        let l = left_keys[i].as_usize();
+        let r = right_keys[j].as_usize();
         cmp(l, r)
-    }))
+    });
+    Ok(f)
+}
+
+fn compare_list<O: OffsetSizeTrait>(
+    left: &dyn Array,
+    right: &dyn Array,
+    opts: SortOptions,
+) -> Result<DynComparator, ArrowError> {
+    let left = left.as_list::<O>();
+    let right = right.as_list::<O>();
+
+    let c_opts = child_opts(opts);
+    let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), 
c_opts)?;
+
+    let l_o = left.offsets().clone();
+    let r_o = right.offsets().clone();
+    let f = compare(left, right, opts, move |i, j| {
+        let l_end = l_o[i + 1].as_usize();
+        let l_start = l_o[i].as_usize();
+
+        let r_end = r_o[j + 1].as_usize();
+        let r_start = r_o[j].as_usize();
+
+        for (i, j) in (l_start..l_end).zip(r_start..r_end) {
+            match cmp(i, j) {
+                Ordering::Equal => continue,
+                r => return r,
+            }
+        }
+        (l_end - l_start).cmp(&(r_end - r_start))
+    });
+    Ok(f)
+}
+
+fn compare_fixed_list(
+    left: &dyn Array,
+    right: &dyn Array,
+    opts: SortOptions,
+) -> Result<DynComparator, ArrowError> {
+    let left = left.as_fixed_size_list();
+    let right = right.as_fixed_size_list();
+
+    let c_opts = child_opts(opts);
+    let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), 
c_opts)?;
+
+    let l_size = left.value_length().to_usize().unwrap();
+    let r_size = right.value_length().to_usize().unwrap();
+    let size_cmp = l_size.cmp(&r_size);
+
+    let f = compare(left, right, opts, move |i, j| {
+        let l_start = i * l_size;
+        let l_end = l_start + l_size;
+        let r_start = j * r_size;
+        let r_end = r_start + r_size;
+        for (i, j) in (l_start..l_end).zip(r_start..r_end) {
+            match cmp(i, j) {
+                Ordering::Equal => continue,
+                r => return r,
+            }
+        }
+        size_cmp
+    });
+    Ok(f)
 }
 
-/// returns a comparison function that compares two values at two different 
positions
+fn compare_struct(
+    left: &dyn Array,
+    right: &dyn Array,
+    opts: SortOptions,
+) -> Result<DynComparator, ArrowError> {
+    let left = left.as_struct();
+    let right = right.as_struct();
+
+    if left.columns().len() != right.columns().len() {
+        return Err(ArrowError::InvalidArgumentError(
+            "Cannot compare StructArray with different number of 
columns".to_string(),
+        ));
+    }
+
+    let c_opts = child_opts(opts);
+    let columns = left.columns().iter().zip(right.columns());
+    let comparators = columns
+        .map(|(l, r)| make_comparator(l, r, c_opts))
+        .collect::<Result<Vec<_>, _>>()?;
+
+    let f = compare(left, right, opts, move |i, j| {
+        for cmp in &comparators {
+            match cmp(i, j) {
+                Ordering::Equal => continue,
+                r => return r,
+            }
+        }
+        Ordering::Equal
+    });
+    Ok(f)
+}
+
+#[deprecated(note = "Use make_comparator")]
+#[doc(hidden)]
+pub fn build_compare(left: &dyn Array, right: &dyn Array) -> 
Result<DynComparator, ArrowError> {
+    make_comparator(left, right, SortOptions::default())
+}
+
+/// Returns a comparison function that compares two values at two different 
positions
 /// between the two arrays.
-/// The arrays' types must be equal.
-/// # Example
-/// ```
-/// use arrow_array::Int32Array;
-/// use arrow_ord::ord::build_compare;
 ///
+/// If `nulls_first` is true `NULL` values will be considered less than any 
non-null value,
+/// otherwise they will be considered greater.
+///
+/// # Basic Usage
+///
+/// ```
+/// # use std::cmp::Ordering;
+/// # use arrow_array::Int32Array;
+/// # use arrow_ord::ord::make_comparator;
+/// # use arrow_schema::SortOptions;
+/// #
 /// let array1 = Int32Array::from(vec![1, 2]);
 /// let array2 = Int32Array::from(vec![3, 4]);
 ///
-/// let cmp = build_compare(&array1, &array2).unwrap();
-///
+/// let cmp = make_comparator(&array1, &array2, 
SortOptions::default()).unwrap();
 /// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
-/// assert_eq!(std::cmp::Ordering::Less, cmp(0, 1));
+/// assert_eq!(cmp(0, 1), Ordering::Less);
+///
+/// let array1 = Int32Array::from(vec![Some(1), None]);
+/// let array2 = Int32Array::from(vec![None, Some(2)]);
+/// let cmp = make_comparator(&array1, &array2, 
SortOptions::default()).unwrap();
+///
+/// assert_eq!(cmp(0, 1), Ordering::Less); // Some(1) vs Some(2)
+/// assert_eq!(cmp(1, 1), Ordering::Less); // None vs Some(2)
+/// assert_eq!(cmp(1, 0), Ordering::Equal); // None vs None
+/// assert_eq!(cmp(0, 0), Ordering::Greater); // Some(1) vs None
 /// ```
-// This is a factory of comparisons.
-// The lifetime 'a enforces that we cannot use the closure beyond any of the 
array's lifetime.
-pub fn build_compare(left: &dyn Array, right: &dyn Array) -> 
Result<DynComparator, ArrowError> {
+///
+/// # Postgres-compatible Nested Comparison
+///
+/// Whilst SQL prescribes ternary logic for nulls, that is comparing a value 
against a null yields
+/// a NULL, many systems, including postgres, instead apply a total ordering 
to comparison
+/// of nested nulls. That is nulls within nested types are either greater than 
any value,
+/// or less than any value (Spark). This could be implemented as below
+///
+/// ```
+/// # use arrow_array::{Array, BooleanArray};
+/// # use arrow_buffer::NullBuffer;
+/// # use arrow_ord::cmp;
+/// # use arrow_ord::ord::make_comparator;
+/// # use arrow_schema::{ArrowError, SortOptions};
+/// fn eq(a: &dyn Array, b: &dyn Array) -> Result<BooleanArray, ArrowError> {

Review Comment:
   I debated adding `SortOptions` to the existing comparison kernels, but that
   would not only be a breaking change, it would also be somewhat incoherent.
   Whilst I accept that Postgres and some other systems apply a total ordering
   to nested nulls when comparing, that behaviour is rather surprising, at
   least to me.
   
   Keeping them separate also makes it more obvious which operations have an
   optimised, vectorised kernel and which don't.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to