This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 9828bf0bcd3 Push SortOptions into DynComparator Allowing Nested Comparisons (#5426) (#5792)
9828bf0bcd3 is described below
commit 9828bf0bcd3e54ff5c51154ee99d183d9ee171fa
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Tue May 28 19:03:07 2024 +0100
Push SortOptions into DynComparator Allowing Nested Comparisons (#5426) (#5792)
* Push SortOptions into DynComparator (#5426)
* Clippy
* Review feedback
* Tweak print_row
---
arrow-ord/src/ord.rs | 551 ++++++++++++++++++++++++++++++++++++++++++-------
arrow-ord/src/sort.rs | 114 ++--------
arrow-row/src/lib.rs | 53 ++++-
arrow/src/array/mod.rs | 3 +-
4 files changed, 542 insertions(+), 179 deletions(-)
diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs
index 8f21cd7c498..3825e5ec66f 100644
--- a/arrow-ord/src/ord.rs
+++ b/arrow-ord/src/ord.rs
@@ -20,36 +20,117 @@
use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::*;
-use arrow_buffer::ArrowNativeType;
-use arrow_schema::ArrowError;
+use arrow_buffer::{ArrowNativeType, NullBuffer};
+use arrow_schema::{ArrowError, SortOptions};
use std::cmp::Ordering;
/// Compare the values at two arbitrary indices in two arrays.
pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;
-fn compare_primitive<T: ArrowPrimitiveType>(left: &dyn Array, right: &dyn
Array) -> DynComparator
+/// Child comparators always sort ascending; if the parent is descending it reverses the
+/// child ordering, so `nulls_first` is inverted here to keep nulls where `opts` places them.
+fn child_opts(opts: SortOptions) -> SortOptions {
+ SortOptions {
+ descending: false,
+ nulls_first: opts.nulls_first != opts.descending,
+ }
+}
+
+fn compare<A, F>(l: &A, r: &A, opts: SortOptions, cmp: F) -> DynComparator
where
- T::Native: ArrowNativeTypeOp,
+ A: Array + Clone,
+ F: Fn(usize, usize) -> Ordering + Send + Sync + 'static,
{
- let left = left.as_primitive::<T>().clone();
- let right = right.as_primitive::<T>().clone();
- Box::new(move |i, j| left.value(i).compare(right.value(j)))
+ let l = l.logical_nulls().filter(|x| x.null_count() > 0);
+ let r = r.logical_nulls().filter(|x| x.null_count() > 0);
+ match (opts.nulls_first, opts.descending) {
+ (true, true) => compare_impl::<true, true, _>(l, r, cmp),
+ (true, false) => compare_impl::<true, false, _>(l, r, cmp),
+ (false, true) => compare_impl::<false, true, _>(l, r, cmp),
+ (false, false) => compare_impl::<false, false, _>(l, r, cmp),
+ }
}
-fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator {
- let left: BooleanArray = left.as_boolean().clone();
- let right: BooleanArray = right.as_boolean().clone();
+fn compare_impl<const NULLS_FIRST: bool, const DESCENDING: bool, F>(
+ l: Option<NullBuffer>,
+ r: Option<NullBuffer>,
+ cmp: F,
+) -> DynComparator
+where
+ F: Fn(usize, usize) -> Ordering + Send + Sync + 'static,
+{
+ let cmp = move |i, j| match DESCENDING {
+ true => cmp(i, j).reverse(),
+ false => cmp(i, j),
+ };
+
+ let (left_null, right_null) = match NULLS_FIRST {
+ true => (Ordering::Less, Ordering::Greater),
+ false => (Ordering::Greater, Ordering::Less),
+ };
+
+ match (l, r) {
+ (None, None) => Box::new(cmp),
+ (Some(l), None) => Box::new(move |i, j| match l.is_null(i) {
+ true => left_null,
+ false => cmp(i, j),
+ }),
+ (None, Some(r)) => Box::new(move |i, j| match r.is_null(j) {
+ true => right_null,
+ false => cmp(i, j),
+ }),
+ (Some(l), Some(r)) => Box::new(move |i, j| match (l.is_null(i),
r.is_null(j)) {
+ (true, true) => Ordering::Equal,
+ (true, false) => left_null,
+ (false, true) => right_null,
+ (false, false) => cmp(i, j),
+ }),
+ }
+}
+
+fn compare_primitive<T: ArrowPrimitiveType>(
+ left: &dyn Array,
+ right: &dyn Array,
+ opts: SortOptions,
+) -> DynComparator
+where
+ T::Native: ArrowNativeTypeOp,
+{
+ let left = left.as_primitive::<T>();
+ let right = right.as_primitive::<T>();
+ let l_values = left.values().clone();
+ let r_values = right.values().clone();
- Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
+ compare(&left, &right, opts, move |i, j| {
+ l_values[i].compare(r_values[j])
+ })
}
-fn compare_bytes<T: ByteArrayType>(left: &dyn Array, right: &dyn Array) ->
DynComparator {
- let left = left.as_bytes::<T>().clone();
- let right = right.as_bytes::<T>().clone();
+fn compare_boolean(left: &dyn Array, right: &dyn Array, opts: SortOptions) ->
DynComparator {
+ let left = left.as_boolean();
+ let right = right.as_boolean();
+
+ let l_values = left.values().clone();
+ let r_values = right.values().clone();
- Box::new(move |i, j| {
- let l: &[u8] = left.value(i).as_ref();
- let r: &[u8] = right.value(j).as_ref();
+ compare(left, right, opts, move |i, j| {
+ l_values.value(i).cmp(&r_values.value(j))
+ })
+}
+
+fn compare_bytes<T: ByteArrayType>(
+ left: &dyn Array,
+ right: &dyn Array,
+ opts: SortOptions,
+) -> DynComparator {
+ let left = left.as_bytes::<T>();
+ let right = right.as_bytes::<T>();
+
+ let l = left.clone();
+ let r = right.clone();
+ compare(left, right, opts, move |i, j| {
+ let l: &[u8] = l.value(i).as_ref();
+ let r: &[u8] = r.value(j).as_ref();
l.cmp(r)
})
}
@@ -57,67 +138,234 @@ fn compare_bytes<T: ByteArrayType>(left: &dyn Array,
right: &dyn Array) -> DynCo
fn compare_dict<K: ArrowDictionaryKeyType>(
left: &dyn Array,
right: &dyn Array,
+ opts: SortOptions,
) -> Result<DynComparator, ArrowError> {
let left = left.as_dictionary::<K>();
let right = right.as_dictionary::<K>();
- let cmp = build_compare(left.values().as_ref(), right.values().as_ref())?;
- let left_keys = left.keys().clone();
- let right_keys = right.keys().clone();
+ let c_opts = child_opts(opts);
+ let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(),
c_opts)?;
+ let left_keys = left.keys().values().clone();
+ let right_keys = right.keys().values().clone();
- // TODO: Handle value nulls (#2687)
- Ok(Box::new(move |i, j| {
- let l = left_keys.value(i).as_usize();
- let r = right_keys.value(j).as_usize();
+ let f = compare(left, right, opts, move |i, j| {
+ let l = left_keys[i].as_usize();
+ let r = right_keys[j].as_usize();
cmp(l, r)
- }))
+ });
+ Ok(f)
}
-/// returns a comparison function that compares two values at two different
positions
+fn compare_list<O: OffsetSizeTrait>(
+ left: &dyn Array,
+ right: &dyn Array,
+ opts: SortOptions,
+) -> Result<DynComparator, ArrowError> {
+ let left = left.as_list::<O>();
+ let right = right.as_list::<O>();
+
+ let c_opts = child_opts(opts);
+ let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(),
c_opts)?;
+
+ let l_o = left.offsets().clone();
+ let r_o = right.offsets().clone();
+ let f = compare(left, right, opts, move |i, j| {
+ let l_end = l_o[i + 1].as_usize();
+ let l_start = l_o[i].as_usize();
+
+ let r_end = r_o[j + 1].as_usize();
+ let r_start = r_o[j].as_usize();
+
+ for (i, j) in (l_start..l_end).zip(r_start..r_end) {
+ match cmp(i, j) {
+ Ordering::Equal => continue,
+ r => return r,
+ }
+ }
+ (l_end - l_start).cmp(&(r_end - r_start))
+ });
+ Ok(f)
+}
+
+fn compare_fixed_list(
+ left: &dyn Array,
+ right: &dyn Array,
+ opts: SortOptions,
+) -> Result<DynComparator, ArrowError> {
+ let left = left.as_fixed_size_list();
+ let right = right.as_fixed_size_list();
+
+ let c_opts = child_opts(opts);
+ let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(),
c_opts)?;
+
+ let l_size = left.value_length().to_usize().unwrap();
+ let r_size = right.value_length().to_usize().unwrap();
+ let size_cmp = l_size.cmp(&r_size);
+
+ let f = compare(left, right, opts, move |i, j| {
+ let l_start = i * l_size;
+ let l_end = l_start + l_size;
+ let r_start = j * r_size;
+ let r_end = r_start + r_size;
+ for (i, j) in (l_start..l_end).zip(r_start..r_end) {
+ match cmp(i, j) {
+ Ordering::Equal => continue,
+ r => return r,
+ }
+ }
+ size_cmp
+ });
+ Ok(f)
+}
+
+fn compare_struct(
+ left: &dyn Array,
+ right: &dyn Array,
+ opts: SortOptions,
+) -> Result<DynComparator, ArrowError> {
+ let left = left.as_struct();
+ let right = right.as_struct();
+
+ if left.columns().len() != right.columns().len() {
+ return Err(ArrowError::InvalidArgumentError(
+ "Cannot compare StructArray with different number of
columns".to_string(),
+ ));
+ }
+
+ let c_opts = child_opts(opts);
+ let columns = left.columns().iter().zip(right.columns());
+ let comparators = columns
+ .map(|(l, r)| make_comparator(l, r, c_opts))
+ .collect::<Result<Vec<_>, _>>()?;
+
+ let f = compare(left, right, opts, move |i, j| {
+ for cmp in &comparators {
+ match cmp(i, j) {
+ Ordering::Equal => continue,
+ r => return r,
+ }
+ }
+ Ordering::Equal
+ });
+ Ok(f)
+}
+
+#[deprecated(note = "Use make_comparator")]
+#[doc(hidden)]
+pub fn build_compare(left: &dyn Array, right: &dyn Array) ->
Result<DynComparator, ArrowError> {
+ make_comparator(left, right, SortOptions::default())
+}
+
+/// Returns a comparison function that compares two values at two different
positions
/// between the two arrays.
-/// The arrays' types must be equal.
-/// # Example
-/// ```
-/// use arrow_array::Int32Array;
-/// use arrow_ord::ord::build_compare;
///
+/// For comparing arrays element-wise, see also the vectorised kernels in
[`crate::cmp`].
+///
+/// If `nulls_first` is true, `NULL` values are considered less than any non-null value,
+/// otherwise greater. Null placement is not affected by `descending`.
+///
+/// # Basic Usage
+///
+/// ```
+/// # use std::cmp::Ordering;
+/// # use arrow_array::Int32Array;
+/// # use arrow_ord::ord::make_comparator;
+/// # use arrow_schema::SortOptions;
+/// #
/// let array1 = Int32Array::from(vec![1, 2]);
/// let array2 = Int32Array::from(vec![3, 4]);
///
-/// let cmp = build_compare(&array1, &array2).unwrap();
-///
+/// let cmp = make_comparator(&array1, &array2,
SortOptions::default()).unwrap();
/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
-/// assert_eq!(std::cmp::Ordering::Less, cmp(0, 1));
+/// assert_eq!(cmp(0, 1), Ordering::Less);
+///
+/// let array1 = Int32Array::from(vec![Some(1), None]);
+/// let array2 = Int32Array::from(vec![None, Some(2)]);
+/// let cmp = make_comparator(&array1, &array2,
SortOptions::default()).unwrap();
+///
+/// assert_eq!(cmp(0, 1), Ordering::Less); // Some(1) vs Some(2)
+/// assert_eq!(cmp(1, 1), Ordering::Less); // None vs Some(2)
+/// assert_eq!(cmp(1, 0), Ordering::Equal); // None vs None
+/// assert_eq!(cmp(0, 0), Ordering::Greater); // Some(1) vs None
/// ```
-// This is a factory of comparisons.
-// The lifetime 'a enforces that we cannot use the closure beyond any of the
array's lifetime.
-pub fn build_compare(left: &dyn Array, right: &dyn Array) ->
Result<DynComparator, ArrowError> {
+///
+/// # Postgres-compatible Nested Comparison
+///
+/// Whilst SQL prescribes ternary logic for nulls (comparing a value against NULL yields
+/// NULL), many systems, including Postgres, instead apply a total ordering when comparing
+/// nulls nested within other types. That is, nested nulls are either greater than any
+/// value (Postgres) or less than any value (Spark).
+///
+/// In particular
+///
+/// ```ignore
+/// { a: 1, b: null } == { a: 1, b: null } => true
+/// { a: 1, b: null } == { a: 1, b: 1 } => false
+/// { a: 1, b: null } == null => null
+/// null == null => null
+/// ```
+///
+/// This could be implemented as below
+///
+/// ```
+/// # use arrow_array::{Array, BooleanArray};
+/// # use arrow_buffer::NullBuffer;
+/// # use arrow_ord::cmp;
+/// # use arrow_ord::ord::make_comparator;
+/// # use arrow_schema::{ArrowError, SortOptions};
+/// fn eq(a: &dyn Array, b: &dyn Array) -> Result<BooleanArray, ArrowError> {
+/// if !a.data_type().is_nested() {
+/// return cmp::eq(&a, &b); // Use faster vectorised kernel
+/// }
+///
+/// let cmp = make_comparator(a, b, SortOptions::default())?;
+/// let len = a.len().min(b.len());
+/// let values = (0..len).map(|i| cmp(i, i).is_eq()).collect();
+/// let nulls = NullBuffer::union(a.nulls(), b.nulls());
+/// Ok(BooleanArray::new(values, nulls))
+/// }
+/// ```
+pub fn make_comparator(
+ left: &dyn Array,
+ right: &dyn Array,
+ opts: SortOptions,
+) -> Result<DynComparator, ArrowError> {
use arrow_schema::DataType::*;
+
macro_rules! primitive_helper {
- ($t:ty, $left:expr, $right:expr) => {
- Ok(compare_primitive::<$t>($left, $right))
+ ($t:ty, $left:expr, $right:expr, $nulls_first:expr) => {
+ Ok(compare_primitive::<$t>($left, $right, $nulls_first))
};
}
downcast_primitive! {
- left.data_type(), right.data_type() => (primitive_helper, left, right),
- (Boolean, Boolean) => Ok(compare_boolean(left, right)),
- (Utf8, Utf8) => Ok(compare_bytes::<Utf8Type>(left, right)),
- (LargeUtf8, LargeUtf8) => Ok(compare_bytes::<LargeUtf8Type>(left,
right)),
- (Binary, Binary) => Ok(compare_bytes::<BinaryType>(left, right)),
- (LargeBinary, LargeBinary) =>
Ok(compare_bytes::<LargeBinaryType>(left, right)),
+ left.data_type(), right.data_type() => (primitive_helper, left, right,
opts),
+ (Boolean, Boolean) => Ok(compare_boolean(left, right, opts)),
+ (Utf8, Utf8) => Ok(compare_bytes::<Utf8Type>(left, right, opts)),
+ (LargeUtf8, LargeUtf8) => Ok(compare_bytes::<LargeUtf8Type>(left,
right, opts)),
+ (Binary, Binary) => Ok(compare_bytes::<BinaryType>(left, right, opts)),
+ (LargeBinary, LargeBinary) =>
Ok(compare_bytes::<LargeBinaryType>(left, right, opts)),
(FixedSizeBinary(_), FixedSizeBinary(_)) => {
- let left = left.as_fixed_size_binary().clone();
- let right = right.as_fixed_size_binary().clone();
- Ok(Box::new(move |i, j| left.value(i).cmp(right.value(j))))
+ let left = left.as_fixed_size_binary();
+ let right = right.as_fixed_size_binary();
+
+ let l = left.clone();
+ let r = right.clone();
+ Ok(compare(left, right, opts, move |i, j| {
+ l.value(i).cmp(r.value(j))
+ }))
},
+ (List(_), List(_)) => compare_list::<i32>(left, right, opts),
+ (LargeList(_), LargeList(_)) => compare_list::<i64>(left, right, opts),
+ (FixedSizeList(_, _), FixedSizeList(_, _)) => compare_fixed_list(left,
right, opts),
+ (Struct(_), Struct(_)) => compare_struct(left, right, opts),
(Dictionary(l_key, _), Dictionary(r_key, _)) => {
macro_rules! dict_helper {
- ($t:ty, $left:expr, $right:expr) => {
- compare_dict::<$t>($left, $right)
+ ($t:ty, $left:expr, $right:expr, $opts: expr) => {
+ compare_dict::<$t>($left, $right, $opts)
};
}
downcast_integer! {
- l_key.as_ref(), r_key.as_ref() => (dict_helper, left, right),
+ l_key.as_ref(), r_key.as_ref() => (dict_helper, left, right,
opts),
_ => unreachable!()
}
},
@@ -131,7 +379,9 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array)
-> Result<DynComparato
#[cfg(test)]
pub mod tests {
use super::*;
+ use arrow_array::builder::{Int32Builder, ListBuilder};
use arrow_buffer::{i256, IntervalDayTime, OffsetBuffer};
+ use arrow_schema::{DataType, Field, Fields};
use half::f16;
use std::sync::Arc;
@@ -140,7 +390,7 @@ pub mod tests {
let items = vec![vec![1u8], vec![2u8]];
let array =
FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap();
- let cmp = build_compare(&array, &array).unwrap();
+ let cmp = make_comparator(&array, &array,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 1));
}
@@ -152,7 +402,7 @@ pub mod tests {
let items = vec![vec![2u8]];
let array2 =
FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap();
- let cmp = build_compare(&array1, &array2).unwrap();
+ let cmp = make_comparator(&array1, &array2,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 0));
}
@@ -161,7 +411,7 @@ pub mod tests {
fn test_i32() {
let array = Int32Array::from(vec![1, 2]);
- let cmp = build_compare(&array, &array).unwrap();
+ let cmp = make_comparator(&array, &array,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, (cmp)(0, 1));
}
@@ -171,7 +421,7 @@ pub mod tests {
let array1 = Int32Array::from(vec![1]);
let array2 = Int32Array::from(vec![2]);
- let cmp = build_compare(&array1, &array2).unwrap();
+ let cmp = make_comparator(&array1, &array2,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 0));
}
@@ -180,7 +430,7 @@ pub mod tests {
fn test_f16() {
let array = Float16Array::from(vec![f16::from_f32(1.0),
f16::from_f32(2.0)]);
- let cmp = build_compare(&array, &array).unwrap();
+ let cmp = make_comparator(&array, &array,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 1));
}
@@ -189,7 +439,7 @@ pub mod tests {
fn test_f64() {
let array = Float64Array::from(vec![1.0, 2.0]);
- let cmp = build_compare(&array, &array).unwrap();
+ let cmp = make_comparator(&array, &array,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 1));
}
@@ -198,7 +448,7 @@ pub mod tests {
fn test_f64_nan() {
let array = Float64Array::from(vec![1.0, f64::NAN]);
- let cmp = build_compare(&array, &array).unwrap();
+ let cmp = make_comparator(&array, &array,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 1));
assert_eq!(Ordering::Equal, cmp(1, 1));
@@ -208,7 +458,7 @@ pub mod tests {
fn test_f64_zeros() {
let array = Float64Array::from(vec![-0.0, 0.0]);
- let cmp = build_compare(&array, &array).unwrap();
+ let cmp = make_comparator(&array, &array,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 1));
assert_eq!(Ordering::Greater, cmp(1, 0));
@@ -225,7 +475,7 @@ pub mod tests {
IntervalDayTimeType::make_value(0, 90_000_000),
]);
- let cmp = build_compare(&array, &array).unwrap();
+ let cmp = make_comparator(&array, &array,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 1));
assert_eq!(Ordering::Greater, cmp(1, 0));
@@ -248,7 +498,7 @@ pub mod tests {
IntervalYearMonthType::make_value(1, 1),
]);
- let cmp = build_compare(&array, &array).unwrap();
+ let cmp = make_comparator(&array, &array,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 1));
assert_eq!(Ordering::Greater, cmp(1, 0));
@@ -269,7 +519,7 @@ pub mod tests {
IntervalMonthDayNanoType::make_value(0, 100, 2),
]);
- let cmp = build_compare(&array, &array).unwrap();
+ let cmp = make_comparator(&array, &array,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 1));
assert_eq!(Ordering::Greater, cmp(1, 0));
@@ -289,7 +539,7 @@ pub mod tests {
.with_precision_and_scale(23, 6)
.unwrap();
- let cmp = build_compare(&array, &array).unwrap();
+ let cmp = make_comparator(&array, &array,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(1, 0));
assert_eq!(Ordering::Greater, cmp(0, 2));
}
@@ -306,7 +556,7 @@ pub mod tests {
.with_precision_and_scale(53, 6)
.unwrap();
- let cmp = build_compare(&array, &array).unwrap();
+ let cmp = make_comparator(&array, &array,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(1, 0));
assert_eq!(Ordering::Greater, cmp(0, 2));
}
@@ -316,7 +566,7 @@ pub mod tests {
let data = vec!["a", "b", "c", "a", "a", "c", "c"];
let array = data.into_iter().collect::<DictionaryArray<Int16Type>>();
- let cmp = build_compare(&array, &array).unwrap();
+ let cmp = make_comparator(&array, &array,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 1));
assert_eq!(Ordering::Equal, cmp(3, 4));
@@ -330,7 +580,7 @@ pub mod tests {
let d2 = vec!["e", "f", "g", "a"];
let a2 = d2.into_iter().collect::<DictionaryArray<Int16Type>>();
- let cmp = build_compare(&a1, &a2).unwrap();
+ let cmp = make_comparator(&a1, &a2, SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 0));
assert_eq!(Ordering::Equal, cmp(0, 3));
@@ -347,7 +597,7 @@ pub mod tests {
let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
let array2 = DictionaryArray::new(keys, Arc::new(values));
- let cmp = build_compare(&array1, &array2).unwrap();
+ let cmp = make_comparator(&array1, &array2,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 0));
assert_eq!(Ordering::Less, cmp(0, 3));
@@ -366,7 +616,7 @@ pub mod tests {
let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
let array2 = DictionaryArray::new(keys, Arc::new(values));
- let cmp = build_compare(&array1, &array2).unwrap();
+ let cmp = make_comparator(&array1, &array2,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 0));
assert_eq!(Ordering::Less, cmp(0, 3));
@@ -385,7 +635,7 @@ pub mod tests {
let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
let array2 = DictionaryArray::new(keys, Arc::new(values));
- let cmp = build_compare(&array1, &array2).unwrap();
+ let cmp = make_comparator(&array1, &array2,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 0));
assert_eq!(Ordering::Less, cmp(0, 3));
@@ -408,7 +658,7 @@ pub mod tests {
let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
let array2 = DictionaryArray::new(keys, Arc::new(values));
- let cmp = build_compare(&array1, &array2).unwrap();
+ let cmp = make_comparator(&array1, &array2,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 0)); // v1 vs v3
assert_eq!(Ordering::Equal, cmp(0, 3)); // v1 vs v1
@@ -427,7 +677,7 @@ pub mod tests {
let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
let array2 = DictionaryArray::new(keys, Arc::new(values));
- let cmp = build_compare(&array1, &array2).unwrap();
+ let cmp = make_comparator(&array1, &array2,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 0));
assert_eq!(Ordering::Less, cmp(0, 3));
@@ -446,7 +696,7 @@ pub mod tests {
let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
let array2 = DictionaryArray::new(keys, Arc::new(values));
- let cmp = build_compare(&array1, &array2).unwrap();
+ let cmp = make_comparator(&array1, &array2,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 0));
assert_eq!(Ordering::Less, cmp(0, 3));
@@ -475,7 +725,7 @@ pub mod tests {
let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
let array2 = DictionaryArray::new(keys, Arc::new(values));
- let cmp = build_compare(&array1, &array2).unwrap();
+ let cmp = make_comparator(&array1, &array2,
SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 0));
assert_eq!(Ordering::Less, cmp(0, 3));
@@ -487,7 +737,7 @@ pub mod tests {
fn test_bytes_impl<T: ByteArrayType>() {
let offsets = OffsetBuffer::from_lengths([3, 3, 1]);
let a = GenericByteArray::<T>::new(offsets, b"abcdefa".into(), None);
- let cmp = build_compare(&a, &a).unwrap();
+ let cmp = make_comparator(&a, &a, SortOptions::default()).unwrap();
assert_eq!(Ordering::Less, cmp(0, 1));
assert_eq!(Ordering::Greater, cmp(0, 2));
@@ -501,4 +751,157 @@ pub mod tests {
test_bytes_impl::<BinaryType>();
test_bytes_impl::<LargeBinaryType>();
}
+
+ #[test]
+ fn test_lists() {
+ let mut a = ListBuilder::new(ListBuilder::new(Int32Builder::new()));
+ a.extend([
+ Some(vec![Some(vec![Some(1), Some(2), None]), Some(vec![None])]),
+ Some(vec![
+ Some(vec![Some(1), Some(2), Some(3)]),
+ Some(vec![Some(1)]),
+ ]),
+ Some(vec![]),
+ ]);
+ let a = a.finish();
+ let mut b = ListBuilder::new(ListBuilder::new(Int32Builder::new()));
+ b.extend([
+ Some(vec![Some(vec![Some(1), Some(2), None]), Some(vec![None])]),
+ Some(vec![
+ Some(vec![Some(1), Some(2), None]),
+ Some(vec![Some(1)]),
+ ]),
+ Some(vec![
+ Some(vec![Some(1), Some(2), Some(3), Some(4)]),
+ Some(vec![Some(1)]),
+ ]),
+ None,
+ ]);
+ let b = b.finish();
+
+ let opts = SortOptions {
+ descending: false,
+ nulls_first: true,
+ };
+ let cmp = make_comparator(&a, &b, opts).unwrap();
+ assert_eq!(cmp(0, 0), Ordering::Equal);
+ assert_eq!(cmp(0, 1), Ordering::Less);
+ assert_eq!(cmp(0, 2), Ordering::Less);
+ assert_eq!(cmp(1, 2), Ordering::Less);
+ assert_eq!(cmp(1, 3), Ordering::Greater);
+ assert_eq!(cmp(2, 0), Ordering::Less);
+
+ let opts = SortOptions {
+ descending: true,
+ nulls_first: true,
+ };
+ let cmp = make_comparator(&a, &b, opts).unwrap();
+ assert_eq!(cmp(0, 0), Ordering::Equal);
+ assert_eq!(cmp(0, 1), Ordering::Less);
+ assert_eq!(cmp(0, 2), Ordering::Less);
+ assert_eq!(cmp(1, 2), Ordering::Greater);
+ assert_eq!(cmp(1, 3), Ordering::Greater);
+ assert_eq!(cmp(2, 0), Ordering::Greater);
+
+ let opts = SortOptions {
+ descending: true,
+ nulls_first: false,
+ };
+ let cmp = make_comparator(&a, &b, opts).unwrap();
+ assert_eq!(cmp(0, 0), Ordering::Equal);
+ assert_eq!(cmp(0, 1), Ordering::Greater);
+ assert_eq!(cmp(0, 2), Ordering::Greater);
+ assert_eq!(cmp(1, 2), Ordering::Greater);
+ assert_eq!(cmp(1, 3), Ordering::Less);
+ assert_eq!(cmp(2, 0), Ordering::Greater);
+
+ let opts = SortOptions {
+ descending: false,
+ nulls_first: false,
+ };
+ let cmp = make_comparator(&a, &b, opts).unwrap();
+ assert_eq!(cmp(0, 0), Ordering::Equal);
+ assert_eq!(cmp(0, 1), Ordering::Greater);
+ assert_eq!(cmp(0, 2), Ordering::Greater);
+ assert_eq!(cmp(1, 2), Ordering::Less);
+ assert_eq!(cmp(1, 3), Ordering::Less);
+ assert_eq!(cmp(2, 0), Ordering::Less);
+ }
+
+    #[test]
+    fn test_struct() {
+        let fields = Fields::from(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new_list("b", Field::new("item", DataType::Int32, true), true),
+        ]);
+        // s1 = [(1, [1, 2]), (2, [None]), (None, None), NULL]
+        let a = Int32Array::from(vec![Some(1), Some(2), None, None]);
+        let mut b = ListBuilder::new(Int32Builder::new());
+        b.extend([Some(vec![Some(1), Some(2)]), Some(vec![None]), None, None]);
+        let b = b.finish();
+
+        let nulls = Some(NullBuffer::from_iter([true, true, true, false]));
+        let values = vec![Arc::new(a) as _, Arc::new(b) as _];
+        let s1 = StructArray::new(fields.clone(), values, nulls);
+        // s2 = [(None, None), (2, None), (None, [])], no struct-level nulls
+        let a = Int32Array::from(vec![None, Some(2), None]);
+        let mut b = ListBuilder::new(Int32Builder::new());
+        b.extend([None, None, Some(vec![])]);
+        let b = b.finish();
+
+        let values = vec![Arc::new(a) as _, Arc::new(b) as _];
+        let s2 = StructArray::new(fields.clone(), values, None);
+
+        let opts = SortOptions {
+            descending: false,
+            nulls_first: true,
+        };
+        let cmp = make_comparator(&s1, &s2, opts).unwrap();
+        assert_eq!(cmp(0, 1), Ordering::Less); // (1, [1, 2]) cmp (2, None)
+        assert_eq!(cmp(0, 0), Ordering::Greater); // (1, [1, 2]) cmp (None, None)
+        assert_eq!(cmp(1, 1), Ordering::Greater); // (2, [None]) cmp (2, None)
+        assert_eq!(cmp(2, 2), Ordering::Less); // (None, None) cmp (None, [])
+        assert_eq!(cmp(3, 2), Ordering::Less); // None cmp (None, [])
+        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
+        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, None)
+
+        let opts = SortOptions {
+            descending: true,
+            nulls_first: true,
+        };
+        let cmp = make_comparator(&s1, &s2, opts).unwrap();
+        assert_eq!(cmp(0, 1), Ordering::Greater); // (1, [1, 2]) cmp (2, None)
+        assert_eq!(cmp(0, 0), Ordering::Greater); // (1, [1, 2]) cmp (None, None)
+        assert_eq!(cmp(1, 1), Ordering::Greater); // (2, [None]) cmp (2, None)
+        assert_eq!(cmp(2, 2), Ordering::Less); // (None, None) cmp (None, [])
+        assert_eq!(cmp(3, 2), Ordering::Less); // None cmp (None, [])
+        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
+        assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, None)
+
+        let opts = SortOptions {
+            descending: true,
+            nulls_first: false,
+        };
+        let cmp = make_comparator(&s1, &s2, opts).unwrap();
+        assert_eq!(cmp(0, 1), Ordering::Greater); // (1, [1, 2]) cmp (2, None)
+        assert_eq!(cmp(0, 0), Ordering::Less); // (1, [1, 2]) cmp (None, None)
+        assert_eq!(cmp(1, 1), Ordering::Less); // (2, [None]) cmp (2, None)
+        assert_eq!(cmp(2, 2), Ordering::Greater); // (None, None) cmp (None, [])
+        assert_eq!(cmp(3, 2), Ordering::Greater); // None cmp (None, [])
+        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
+        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, None)
+
+        let opts = SortOptions {
+            descending: false,
+            nulls_first: false,
+        };
+        let cmp = make_comparator(&s1, &s2, opts).unwrap();
+        assert_eq!(cmp(0, 1), Ordering::Less); // (1, [1, 2]) cmp (2, None)
+        assert_eq!(cmp(0, 0), Ordering::Less); // (1, [1, 2]) cmp (None, None)
+        assert_eq!(cmp(1, 1), Ordering::Less); // (2, [None]) cmp (2, None)
+        assert_eq!(cmp(2, 2), Ordering::Greater); // (None, None) cmp (None, [])
+        assert_eq!(cmp(3, 2), Ordering::Greater); // None cmp (None, [])
+        assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
+        assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, None)
+    }
}
diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs
index fe3a1f86ac0..8ae87787d28 100644
--- a/arrow-ord/src/sort.rs
+++ b/arrow-ord/src/sort.rs
@@ -17,13 +17,13 @@
//! Defines sort kernel for `ArrayRef`
-use crate::ord::{build_compare, DynComparator};
+use crate::ord::{make_comparator, DynComparator};
use arrow_array::builder::BufferBuilder;
use arrow_array::cast::*;
use arrow_array::types::*;
use arrow_array::*;
+use arrow_buffer::ArrowNativeType;
use arrow_buffer::BooleanBufferBuilder;
-use arrow_buffer::{ArrowNativeType, NullBuffer};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{ArrowError, DataType};
use arrow_select::take::take;
@@ -704,60 +704,21 @@ where
}
}
-type LexicographicalCompareItem = (
- Option<NullBuffer>, // nulls
- DynComparator, // comparator
- SortOptions, // sort_option
-);
-
/// A lexicographical comparator that wraps given array data (columns) and can
lexicographically compare data
/// at given two indices. The lifetime is the same at the data wrapped.
pub struct LexicographicalComparator {
- compare_items: Vec<LexicographicalCompareItem>,
+ compare_items: Vec<DynComparator>,
}
impl LexicographicalComparator {
/// lexicographically compare values at the wrapped columns with given
indices.
pub fn compare(&self, a_idx: usize, b_idx: usize) -> Ordering {
- for (nulls, comparator, sort_option) in &self.compare_items {
- let (lhs_valid, rhs_valid) = match nulls {
- Some(n) => (n.is_valid(a_idx), n.is_valid(b_idx)),
- None => (true, true),
- };
-
- match (lhs_valid, rhs_valid) {
- (true, true) => {
- match (comparator)(a_idx, b_idx) {
- // equal, move on to next column
- Ordering::Equal => continue,
- order => {
- if sort_option.descending {
- return order.reverse();
- } else {
- return order;
- }
- }
- }
- }
- (false, true) => {
- return if sort_option.nulls_first {
- Ordering::Less
- } else {
- Ordering::Greater
- };
- }
- (true, false) => {
- return if sort_option.nulls_first {
- Ordering::Greater
- } else {
- Ordering::Less
- };
- }
- // equal, move on to next column
- (false, false) => continue,
+ for comparator in &self.compare_items {
+ match comparator(a_idx, b_idx) {
+ Ordering::Equal => continue,
+ r => return r,
}
}
-
Ordering::Equal
}
@@ -766,61 +727,16 @@ impl LexicographicalComparator {
pub fn try_new(columns: &[SortColumn]) ->
Result<LexicographicalComparator, ArrowError> {
let compare_items = columns
.iter()
- .map(Self::build_compare_item)
+ .map(|c| {
+ make_comparator(
+ c.values.as_ref(),
+ c.values.as_ref(),
+ c.options.unwrap_or_default(),
+ )
+ })
.collect::<Result<Vec<_>, ArrowError>>()?;
Ok(LexicographicalComparator { compare_items })
}
-
- fn build_compare_item(column: &SortColumn) ->
Result<LexicographicalCompareItem, ArrowError> {
- let values = column.values.as_ref();
- let options = column.options.unwrap_or_default();
- let comparator = match values.data_type() {
- DataType::List(_) =>
Self::build_list_compare(values.as_list::<i32>(), options)?,
- DataType::LargeList(_) =>
Self::build_list_compare(values.as_list::<i64>(), options)?,
- DataType::FixedSizeList(_, _) => {
-
Self::build_fixed_size_list_compare(values.as_fixed_size_list(), options)?
- }
- _ => build_compare(values, values)?,
- };
- Ok((values.logical_nulls(), comparator, options))
- }
-
- fn build_list_compare<O: OffsetSizeTrait>(
- array: &GenericListArray<O>,
- options: SortOptions,
- ) -> Result<DynComparator, ArrowError> {
- let rank = child_rank(array.values().as_ref(), options)?;
- let offsets = array.offsets().clone();
- let cmp = Box::new(move |i: usize, j: usize| {
- macro_rules! nth_value {
- ($INDEX:expr) => {{
- let end = offsets[$INDEX + 1].as_usize();
- let start = offsets[$INDEX].as_usize();
- &rank[start..end]
- }};
- }
- Ord::cmp(nth_value!(i), nth_value!(j))
- });
- Ok(cmp)
- }
-
- fn build_fixed_size_list_compare(
- array: &FixedSizeListArray,
- options: SortOptions,
- ) -> Result<DynComparator, ArrowError> {
- let rank = child_rank(array.values().as_ref(), options)?;
- let size = array.value_length() as usize;
- let cmp = Box::new(move |i: usize, j: usize| {
- macro_rules! nth_value {
- ($INDEX:expr) => {{
- let start = $INDEX * size;
- &rank[start..start + size]
- }};
- }
- Ord::cmp(nth_value!(i), nth_value!(j))
- });
- Ok(cmp)
- }
}
#[cfg(test)]
@@ -829,7 +745,7 @@ mod tests {
use arrow_array::builder::{
FixedSizeListBuilder, Int64Builder, ListBuilder, PrimitiveRunBuilder,
};
- use arrow_buffer::i256;
+ use arrow_buffer::{i256, NullBuffer};
use half::f16;
use rand::rngs::StdRng;
use rand::{Rng, RngCore, SeedableRng};
diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs
index 4dc7349ca2d..8e1285493b0 100644
--- a/arrow-row/src/lib.rs
+++ b/arrow-row/src/lib.rs
@@ -1302,9 +1302,9 @@ mod tests {
use arrow_array::builder::*;
use arrow_array::types::*;
use arrow_array::*;
- use arrow_buffer::i256;
- use arrow_buffer::Buffer;
- use arrow_cast::display::array_value_to_string;
+ use arrow_buffer::{i256, NullBuffer};
+ use arrow_buffer::{Buffer, OffsetBuffer};
+ use arrow_cast::display::{ArrayFormatter, FormatOptions};
use arrow_ord::sort::{LexicographicalComparator, SortColumn};
use super::*;
@@ -2099,9 +2099,35 @@ mod tests {
builder.finish()
}
+ fn generate_struct(len: usize, valid_percent: f64) -> StructArray {
+ let mut rng = thread_rng();
+ let nulls = NullBuffer::from_iter((0..len).map(|_|
rng.gen_bool(valid_percent)));
+ let a = generate_primitive_array::<Int32Type>(len, valid_percent);
+ let b = generate_strings::<i32>(len, valid_percent);
+ let fields = Fields::from(vec![
+ Field::new("a", DataType::Int32, true),
+ Field::new("b", DataType::Utf8, true),
+ ]);
+ let values = vec![Arc::new(a) as _, Arc::new(b) as _];
+ StructArray::new(fields, values, Some(nulls))
+ }
+
+ fn generate_list<F>(len: usize, valid_percent: f64, values: F) -> ListArray
+ where
+ F: FnOnce(usize) -> ArrayRef,
+ {
+ let mut rng = thread_rng();
+ let offsets = OffsetBuffer::<i32>::from_lengths((0..len).map(|_|
rng.gen_range(0..10)));
+ let values_len = offsets.last().unwrap().to_usize().unwrap();
+ let values = values(values_len);
+ let nulls = NullBuffer::from_iter((0..len).map(|_|
rng.gen_bool(valid_percent)));
+ let field = Arc::new(Field::new("item", values.data_type().clone(),
true));
+ ListArray::new(field, offsets, values, Some(nulls))
+ }
+
fn generate_column(len: usize) -> ArrayRef {
let mut rng = thread_rng();
- match rng.gen_range(0..10) {
+ match rng.gen_range(0..14) {
0 => Arc::new(generate_primitive_array::<Int32Type>(len, 0.8)),
1 => Arc::new(generate_primitive_array::<UInt32Type>(len, 0.8)),
2 => Arc::new(generate_primitive_array::<Int64Type>(len, 0.8)),
@@ -2125,6 +2151,16 @@ mod tests {
0.8,
)),
9 => Arc::new(generate_fixed_size_binary(len, 0.8)),
+ 10 => Arc::new(generate_struct(len, 0.8)),
+ 11 => Arc::new(generate_list(len, 0.8, |values_len| {
+ Arc::new(generate_primitive_array::<Int64Type>(values_len,
0.8))
+ })),
+ 12 => Arc::new(generate_list(len, 0.8, |values_len| {
+ Arc::new(generate_strings::<i32>(values_len, 0.8))
+ })),
+ 13 => Arc::new(generate_list(len, 0.8, |values_len| {
+ Arc::new(generate_struct(values_len, 0.8))
+ })),
_ => unreachable!(),
}
}
@@ -2132,7 +2168,14 @@ mod tests {
fn print_row(cols: &[SortColumn], row: usize) -> String {
let t: Vec<_> = cols
.iter()
- .map(|x| array_value_to_string(&x.values, row).unwrap())
+ .map(|x| match x.values.is_valid(row) {
+ true => {
+ let opts = FormatOptions::default().with_null("NULL");
+ let formatter = ArrayFormatter::try_new(x.values.as_ref(),
&opts).unwrap();
+ formatter.value(row).to_string()
+ }
+ false => "NULL".to_string(),
+ })
.collect();
t.join(",")
}
diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs
index b563c320bb6..242c9148cac 100644
--- a/arrow/src/array/mod.rs
+++ b/arrow/src/array/mod.rs
@@ -36,4 +36,5 @@ pub use arrow_array::ffi::export_array_into_raw;
// --------------------- Array's values comparison ---------------------
-pub use arrow_ord::ord::{build_compare, DynComparator};
+#[allow(deprecated)]
+pub use arrow_ord::ord::{build_compare, make_comparator, DynComparator};