This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 86c3fdba21 Add rank function (#4606) (#4609)
86c3fdba21 is described below
commit 86c3fdba211762de44dbff0a9578cb6c5f694af6
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Tue Aug 1 10:02:35 2023 +0100
Add rank function (#4606) (#4609)
* Add rank function (#4606)
* Add benchmarks
* Add inline attribute
---
arrow-ord/src/lib.rs | 1 +
arrow-ord/src/rank.rs | 195 +++++++++++++++++++++++++++++++++++++++++++
arrow/benches/sort_kernel.rs | 29 ++++++-
arrow/src/compute/kernels.rs | 2 +-
arrow/src/compute/mod.rs | 1 +
5 files changed, 223 insertions(+), 5 deletions(-)
diff --git a/arrow-ord/src/lib.rs b/arrow-ord/src/lib.rs
index 62338c0223..8b43cdb0bf 100644
--- a/arrow-ord/src/lib.rs
+++ b/arrow-ord/src/lib.rs
@@ -46,4 +46,5 @@
pub mod comparison;
pub mod ord;
pub mod partition;
+pub mod rank;
pub mod sort;
diff --git a/arrow-ord/src/rank.rs b/arrow-ord/src/rank.rs
new file mode 100644
index 0000000000..1e79156a71
--- /dev/null
+++ b/arrow-ord/src/rank.rs
@@ -0,0 +1,195 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_array::cast::AsArray;
+use arrow_array::types::*;
+use arrow_array::{downcast_primitive_array, Array, ArrowNativeTypeOp,
GenericByteArray};
+use arrow_buffer::NullBuffer;
+use arrow_schema::{ArrowError, DataType, SortOptions};
+use std::cmp::Ordering;
+
+/// Assigns a rank to each value in `array` based on its position in the
sorted order
+///
+/// Where values are equal, they will be assigned the highest of their ranks,
+/// leaving gaps in the overall rank assignment
+///
+/// ```
+/// # use arrow_array::StringArray;
+/// # use arrow_ord::rank::rank;
+/// let array = StringArray::from(vec![Some("foo"), None, Some("foo"), None,
Some("bar")]);
+/// let ranks = rank(&array, None).unwrap();
+/// assert_eq!(ranks, &[5, 2, 5, 2, 3]);
+/// ```
+pub fn rank(
+ array: &dyn Array,
+ options: Option<SortOptions>,
+) -> Result<Vec<u32>, ArrowError> {
+ let options = options.unwrap_or_default();
+ let ranks = downcast_primitive_array! {
+ array => primitive_rank(array.values(), array.nulls(), options),
+ DataType::Utf8 => bytes_rank(array.as_bytes::<Utf8Type>(), options),
+ DataType::LargeUtf8 => bytes_rank(array.as_bytes::<LargeUtf8Type>(),
options),
+ DataType::Binary => bytes_rank(array.as_bytes::<BinaryType>(),
options),
+ DataType::LargeBinary =>
bytes_rank(array.as_bytes::<LargeBinaryType>(), options),
+ d => return Err(ArrowError::ComputeError(format!("{d:?} not supported
in rank")))
+ };
+ Ok(ranks)
+}
+
+#[inline(never)]
+fn primitive_rank<T: ArrowNativeTypeOp>(
+ values: &[T],
+ nulls: Option<&NullBuffer>,
+ options: SortOptions,
+) -> Vec<u32> {
+ let len: u32 = values.len().try_into().unwrap();
+ let to_sort = match nulls.filter(|n| n.null_count() > 0) {
+ Some(n) => n
+ .valid_indices()
+ .map(|idx| (values[idx], idx as u32))
+ .collect(),
+ None => values.iter().copied().zip(0..len).collect(),
+ };
+ rank_impl(values.len(), to_sort, options, T::compare, T::is_eq)
+}
+
+#[inline(never)]
+fn bytes_rank<T: ByteArrayType>(
+ array: &GenericByteArray<T>,
+ options: SortOptions,
+) -> Vec<u32> {
+ let to_sort: Vec<(&[u8], u32)> = match array.nulls().filter(|n|
n.null_count() > 0) {
+ Some(n) => n
+ .valid_indices()
+ .map(|idx| (array.value(idx).as_ref(), idx as u32))
+ .collect(),
+ None => (0..array.len())
+ .map(|idx| (array.value(idx).as_ref(), idx as u32))
+ .collect(),
+ };
+ rank_impl(array.len(), to_sort, options, Ord::cmp, PartialEq::eq)
+}
+
+fn rank_impl<T, C, E>(
+ len: usize,
+ mut valid: Vec<(T, u32)>,
+ options: SortOptions,
+ compare: C,
+ eq: E,
+) -> Vec<u32>
+where
+ T: Copy,
+ C: Fn(T, T) -> Ordering,
+ E: Fn(T, T) -> bool,
+{
+ // We can use an unstable sort as we combine equal values later
+ valid.sort_unstable_by(|a, b| compare(a.0, b.0));
+ if options.descending {
+ valid.reverse();
+ }
+
+ let (mut valid_rank, null_rank) = match options.nulls_first {
+ true => (len as u32, (len - valid.len()) as u32),
+ false => (valid.len() as u32, len as u32),
+ };
+
+ let mut out: Vec<_> = vec![null_rank; len];
+ if let Some(v) = valid.last() {
+ out[v.1 as usize] = valid_rank;
+ }
+
+ let mut count = 1; // Number of values in rank
+ for w in valid.windows(2).rev() {
+ match eq(w[0].0, w[1].0) {
+ true => {
+ count += 1;
+ out[w[0].1 as usize] = valid_rank;
+ }
+ false => {
+ valid_rank -= count;
+ count = 1;
+ out[w[0].1 as usize] = valid_rank
+ }
+ }
+ }
+
+ out
+}
+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::*;

    #[test]
    fn test_primitive() {
        let descending = SortOptions {
            descending: true,
            nulls_first: true,
        };

        let nulls_last = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let nulls_last_descending = SortOptions {
            descending: true,
            nulls_first: false,
        };

        let a = Int32Array::from(vec![Some(1), Some(1), None, Some(3), Some(3), Some(4)]);
        let res = rank(&a, None).unwrap();
        assert_eq!(res, &[3, 3, 1, 5, 5, 6]);

        let res = rank(&a, Some(descending)).unwrap();
        assert_eq!(res, &[6, 6, 1, 4, 4, 2]);

        let res = rank(&a, Some(nulls_last)).unwrap();
        assert_eq!(res, &[2, 2, 6, 4, 4, 5]);

        let res = rank(&a, Some(nulls_last_descending)).unwrap();
        assert_eq!(res, &[5, 5, 6, 3, 3, 1]);

        // Null slots whose underlying buffer values are non-zero must still be
        // ranked as nulls, not by the masked-out value.
        let nulls = NullBuffer::from(vec![true, true, false, true, false, false]);
        let a = Int32Array::new(vec![1, 4, 3, 4, 5, 5].into(), Some(nulls));
        let res = rank(&a, None).unwrap();
        assert_eq!(res, &[4, 6, 3, 6, 3, 3]);
    }

    #[test]
    fn test_float() {
        let a = Float64Array::from(vec![2.0, 1.0, 2.0]);
        assert_eq!(rank(&a, None).unwrap(), &[3, 1, 3]);
    }

    #[test]
    fn test_empty_and_all_null() {
        let empty = Int32Array::from(Vec::<i32>::new());
        assert!(rank(&empty, None).unwrap().is_empty());

        let nulls_last = SortOptions {
            descending: false,
            nulls_first: false,
        };

        // With no valid values, every null shares the highest rank.
        let all_null = Int32Array::from(vec![None::<i32>, None, None]);
        assert_eq!(rank(&all_null, None).unwrap(), &[3, 3, 3]);
        assert_eq!(rank(&all_null, Some(nulls_last)).unwrap(), &[3, 3, 3]);

        let all_null = StringArray::from(vec![None::<&str>, None]);
        assert_eq!(rank(&all_null, None).unwrap(), &[2, 2]);
    }

    #[test]
    fn test_unsupported_type() {
        let bools = BooleanArray::from(vec![true, false]);
        assert!(rank(&bools, None).is_err());
    }

    #[test]
    fn test_bytes() {
        let v = vec!["foo", "fo", "bar", "bar"];
        let values = StringArray::from(v.clone());
        let res = rank(&values, None).unwrap();
        assert_eq!(res, &[4, 3, 2, 2]);

        let values = LargeStringArray::from(v.clone());
        let res = rank(&values, None).unwrap();
        assert_eq!(res, &[4, 3, 2, 2]);

        let v: Vec<&[u8]> = vec![&[1, 2], &[0], &[1, 2, 3], &[1, 2]];
        let values = LargeBinaryArray::from(v.clone());
        let res = rank(&values, None).unwrap();
        assert_eq!(res, &[3, 1, 4, 3]);

        let values = BinaryArray::from(v);
        let res = rank(&values, None).unwrap();
        assert_eq!(res, &[3, 1, 4, 3]);
    }

    #[test]
    fn test_bytes_with_nulls() {
        let values = StringArray::from(vec![Some("b"), None, Some("a"), Some("b")]);

        let res = rank(&values, None).unwrap();
        assert_eq!(res, &[4, 1, 2, 4]);

        let descending = SortOptions {
            descending: true,
            nulls_first: true,
        };
        let res = rank(&values, Some(descending)).unwrap();
        assert_eq!(res, &[3, 1, 4, 3]);

        let nulls_last = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let res = rank(&values, Some(nulls_last)).unwrap();
        assert_eq!(res, &[3, 4, 1, 3]);

        let nulls_last_descending = SortOptions {
            descending: true,
            nulls_first: false,
        };
        let res = rank(&values, Some(nulls_last_descending)).unwrap();
        assert_eq!(res, &[2, 4, 3, 2]);
    }
}
diff --git a/arrow/benches/sort_kernel.rs b/arrow/benches/sort_kernel.rs
index 8762d9eb2f..3a3ce4462d 100644
--- a/arrow/benches/sort_kernel.rs
+++ b/arrow/benches/sort_kernel.rs
@@ -17,7 +17,7 @@
#[macro_use]
extern crate criterion;
-use criterion::Criterion;
+use criterion::{black_box, Criterion};
use std::sync::Arc;
@@ -27,6 +27,7 @@ use arrow::compute::{lexsort, sort, sort_to_indices,
SortColumn};
use arrow::datatypes::{Int16Type, Int32Type};
use arrow::util::bench_util::*;
use arrow::{array::*, datatypes::Float32Type};
+use arrow_ord::rank::rank;
fn create_f32_array(size: usize, with_nulls: bool) -> ArrayRef {
let null_density = if with_nulls { 0.5 } else { 0.0 };
@@ -42,7 +43,7 @@ fn create_bool_array(size: usize, with_nulls: bool) ->
ArrayRef {
}
fn bench_sort(array: &dyn Array) {
- criterion::black_box(sort(array, None).unwrap());
+ black_box(sort(array, None).unwrap());
}
fn bench_lexsort(array_a: &ArrayRef, array_b: &ArrayRef, limit: Option<usize>)
{
@@ -57,11 +58,11 @@ fn bench_lexsort(array_a: &ArrayRef, array_b: &ArrayRef,
limit: Option<usize>) {
},
];
- criterion::black_box(lexsort(&columns, limit).unwrap());
+ black_box(lexsort(&columns, limit).unwrap());
}
fn bench_sort_to_indices(array: &dyn Array, limit: Option<usize>) {
- criterion::black_box(sort_to_indices(array, None, limit).unwrap());
+ black_box(sort_to_indices(array, None, limit).unwrap());
}
fn add_benchmark(c: &mut Criterion) {
@@ -199,6 +200,26 @@ fn add_benchmark(c: &mut Criterion) {
c.bench_function("lexsort (f32, f32) nulls 2^12 limit 2^12", |b| {
b.iter(|| bench_lexsort(&arr_a, &arr_b, Some(2usize.pow(12))))
});
+
+ let arr = create_f32_array(2usize.pow(12), false);
+ c.bench_function("rank f32 2^12", |b| {
+ b.iter(|| black_box(rank(&arr, None).unwrap()))
+ });
+
+ let arr = create_f32_array(2usize.pow(12), true);
+ c.bench_function("rank f32 nulls 2^12", |b| {
+ b.iter(|| black_box(rank(&arr, None).unwrap()))
+ });
+
+ let arr = create_string_array_with_len::<i32>(2usize.pow(12), 0.0, 10);
+ c.bench_function("rank string[10] 2^12", |b| {
+ b.iter(|| black_box(rank(&arr, None).unwrap()))
+ });
+
+ let arr = create_string_array_with_len::<i32>(2usize.pow(12), 0.5, 10);
+ c.bench_function("rank string[10] nulls 2^12", |b| {
+ b.iter(|| black_box(rank(&arr, None).unwrap()))
+ });
}
criterion_group!(benches, add_benchmark);
diff --git a/arrow/src/compute/kernels.rs b/arrow/src/compute/kernels.rs
index 1a79aef547..faff1b8a0d 100644
--- a/arrow/src/compute/kernels.rs
+++ b/arrow/src/compute/kernels.rs
@@ -22,7 +22,7 @@ pub use arrow_arith::{
};
pub use arrow_cast::cast;
pub use arrow_cast::parse as cast_utils;
-pub use arrow_ord::{partition, sort};
+pub use arrow_ord::{partition, rank, sort};
pub use arrow_select::{concat, filter, interleave, nullif, take, window, zip};
pub use arrow_string::{concat_elements, length, regexp, substring};
diff --git a/arrow/src/compute/mod.rs b/arrow/src/compute/mod.rs
index 7cfe787b08..47a9d149aa 100644
--- a/arrow/src/compute/mod.rs
+++ b/arrow/src/compute/mod.rs
@@ -30,6 +30,7 @@ pub use self::kernels::filter::*;
pub use self::kernels::interleave::*;
pub use self::kernels::nullif::*;
pub use self::kernels::partition::*;
+pub use self::kernels::rank::*;
pub use self::kernels::regexp::*;
pub use self::kernels::sort::*;
pub use self::kernels::take::*;