This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new af40ea382 Implement specialized min/max for `GenericBinaryView`
(`StringView` and `BinaryView`) (#6089)
af40ea382 is described below
commit af40ea382275dba967bfabc1632fded07d2129b9
Author: Xiangpeng Hao <[email protected]>
AuthorDate: Tue Jul 23 15:48:25 2024 -0400
Implement specialized min/max for `GenericBinaryView` (`StringView` and
`BinaryView`) (#6089)
* implement better min/max for string view
* Apply suggestions from code review
Co-authored-by: Andrew Lamb <[email protected]>
* address review comments
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
arrow-arith/src/aggregate.rs | 51 ++++++++++++++++++++++++---
arrow-array/src/array/byte_view_array.rs | 60 ++++++++++++++++++++++++++++++++
arrow-ord/src/cmp.rs | 7 ++--
arrow/benches/aggregate_kernels.rs | 23 ++++++++++--
4 files changed, 132 insertions(+), 9 deletions(-)
diff --git a/arrow-arith/src/aggregate.rs b/arrow-arith/src/aggregate.rs
index f526268bf..a81e7fc1a 100644
--- a/arrow-arith/src/aggregate.rs
+++ b/arrow-arith/src/aggregate.rs
@@ -24,7 +24,9 @@ use arrow_buffer::{ArrowNativeType, NullBuffer};
use arrow_data::bit_iterator::try_for_each_valid_idx;
use arrow_schema::*;
use std::borrow::BorrowMut;
+use std::cmp::{self, Ordering};
use std::ops::{BitAnd, BitOr, BitXor};
+use types::ByteViewType;
/// An accumulator for primitive numeric values.
trait NumericAccumulator<T: ArrowNativeTypeOp>: Copy + Default {
@@ -425,6 +427,47 @@ where
}
}
+/// Helper to compute min/max of [`GenericByteViewArray<T>`].
+/// The specialized min/max leverages the inlined values to compare the byte
views.
+/// `swap_cond` is the condition to swap current min/max with the new value.
+/// For example, `Ordering::Greater` for max and `Ordering::Less` for min.
+fn min_max_view_helper<T: ByteViewType>(
+ array: &GenericByteViewArray<T>,
+ swap_cond: cmp::Ordering,
+) -> Option<&T::Native> {
+ let null_count = array.null_count();
+ if null_count == array.len() {
+ None
+ } else if null_count == 0 {
+ let target_idx = (0..array.len()).reduce(|acc, item| {
+ // SAFETY: array's length is correct so item is within bounds
+ let cmp = unsafe { GenericByteViewArray::compare_unchecked(array,
item, array, acc) };
+ if cmp == swap_cond {
+ item
+ } else {
+ acc
+ }
+ });
+ // SAFETY: idx came from valid range `0..array.len()`
+ unsafe { target_idx.map(|idx| array.value_unchecked(idx)) }
+ } else {
+ let nulls = array.nulls().unwrap();
+
+ let target_idx = nulls.valid_indices().reduce(|acc_idx, idx| {
+ let cmp =
+ unsafe { GenericByteViewArray::compare_unchecked(array, idx,
array, acc_idx) };
+ if cmp == swap_cond {
+ idx
+ } else {
+ acc_idx
+ }
+ });
+
+ // SAFETY: idx came from valid range `0..array.len()`
+ unsafe { target_idx.map(|idx| array.value_unchecked(idx)) }
+ }
+}
+
/// Returns the maximum value in the binary array, according to the natural
order.
pub fn max_binary<T: OffsetSizeTrait>(array: &GenericBinaryArray<T>) ->
Option<&[u8]> {
min_max_helper::<&[u8], _, _>(array, |a, b| *a < *b)
@@ -432,7 +475,7 @@ pub fn max_binary<T: OffsetSizeTrait>(array:
&GenericBinaryArray<T>) -> Option<&
/// Returns the maximum value in the binary view array, according to the
natural order.
pub fn max_binary_view(array: &BinaryViewArray) -> Option<&[u8]> {
- min_max_helper::<&[u8], _, _>(array, |a, b| *a < *b)
+ min_max_view_helper(array, Ordering::Greater)
}
/// Returns the minimum value in the binary array, according to the natural
order.
@@ -442,7 +485,7 @@ pub fn min_binary<T: OffsetSizeTrait>(array:
&GenericBinaryArray<T>) -> Option<&
/// Returns the minimum value in the binary view array, according to the
natural order.
pub fn min_binary_view(array: &BinaryViewArray) -> Option<&[u8]> {
- min_max_helper::<&[u8], _, _>(array, |a, b| *a > *b)
+ min_max_view_helper(array, Ordering::Less)
}
/// Returns the maximum value in the string array, according to the natural
order.
@@ -452,7 +495,7 @@ pub fn max_string<T: OffsetSizeTrait>(array:
&GenericStringArray<T>) -> Option<&
/// Returns the maximum value in the string view array, according to the
natural order.
pub fn max_string_view(array: &StringViewArray) -> Option<&str> {
- min_max_helper::<&str, _, _>(array, |a, b| *a < *b)
+ min_max_view_helper(array, Ordering::Greater)
}
/// Returns the minimum value in the string array, according to the natural
order.
@@ -462,7 +505,7 @@ pub fn min_string<T: OffsetSizeTrait>(array:
&GenericStringArray<T>) -> Option<&
/// Returns the minimum value in the string view array, according to the
natural order.
pub fn min_string_view(array: &StringViewArray) -> Option<&str> {
- min_max_helper::<&str, _, _>(array, |a, b| *a > *b)
+ min_max_view_helper(array, Ordering::Less)
}
/// Returns the sum of values in the array.
diff --git a/arrow-array/src/array/byte_view_array.rs
b/arrow-array/src/array/byte_view_array.rs
index 7017add49..bd8c0cebc 100644
--- a/arrow-array/src/array/byte_view_array.rs
+++ b/arrow-array/src/array/byte_view_array.rs
@@ -336,6 +336,66 @@ impl<T: ByteViewType + ?Sized> GenericByteViewArray<T> {
builder.finish()
}
+
+ /// Comparing two [`GenericByteViewArray`] at index `left_idx` and
`right_idx`
+ ///
+ /// Comparing two ByteView values is non-trivial.
+ /// It takes a bit of patience to understand why we don't just compare two
&[u8] directly.
+ ///
+ /// ByteView types give us the following two advantages, and we need to be
careful not to lose them:
+ /// (1) For strings/bytes of at most 12 bytes, the entire data is inlined
in the view.
+ /// This means that reading one array element requires only one memory
access
+ /// (StringArray requires two memory accesses: one for the offset
buffer, the other for the value buffer).
+ ///
+ /// (2) For strings/bytes longer than 12 bytes, we can still be faster than
StringArray/ByteArray (for certain operations),
+ /// thanks to the inlined 4 bytes.
+ /// Consider equality check:
+ /// If the first four bytes of the two strings are different, we can
return false immediately (with just one memory access).
+ ///
+ /// If we directly compare two &[u8], we materialize the entire string
(i.e., make multiple memory accesses), which might be unnecessary.
+ /// - Most of the time (eq, ord), we only need to look at the first 4
bytes to know the answer,
+ /// e.g., if the inlined 4 bytes are different, we can directly return
unequal without looking at the full string.
+ ///
+ /// # Order check flow
+ /// (1) if both strings are at most 12 bytes long, we can directly compare
the data inlined in the view.
+ /// (2) if either string is longer than 12 bytes, we may need to compare
the full string.
+ /// (2.1) if the inlined 4 bytes are different, we can return the
result immediately.
+ /// (2.2) otherwise, we need to compare the full string.
+ ///
+ /// # Safety
+ /// The left/right_idx must be within range of each array
+ pub unsafe fn compare_unchecked(
+ left: &GenericByteViewArray<T>,
+ left_idx: usize,
+ right: &GenericByteViewArray<T>,
+ right_idx: usize,
+ ) -> std::cmp::Ordering {
+ let l_view = left.views().get_unchecked(left_idx);
+ let l_len = *l_view as u32;
+
+ let r_view = right.views().get_unchecked(right_idx);
+ let r_len = *r_view as u32;
+
+ if l_len <= 12 && r_len <= 12 {
+ let l_data = unsafe {
GenericByteViewArray::<T>::inline_value(l_view, l_len as usize) };
+ let r_data = unsafe {
GenericByteViewArray::<T>::inline_value(r_view, r_len as usize) };
+ return l_data.cmp(r_data);
+ }
+
+ // at least one of the strings is longer than 12 bytes,
+ // so we try to compare the inlined 4 bytes first
+ let l_inlined_data = unsafe {
GenericByteViewArray::<T>::inline_value(l_view, 4) };
+ let r_inlined_data = unsafe {
GenericByteViewArray::<T>::inline_value(r_view, 4) };
+ if r_inlined_data != l_inlined_data {
+ return l_inlined_data.cmp(r_inlined_data);
+ }
+
+ // unfortunately, we need to compare the full data
+ let l_full_data: &[u8] = unsafe {
left.value_unchecked(left_idx).as_ref() };
+ let r_full_data: &[u8] = unsafe {
right.value_unchecked(right_idx).as_ref() };
+
+ l_full_data.cmp(r_full_data)
+ }
}
impl<T: ByteViewType + ?Sized> Debug for GenericByteViewArray<T> {
diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs
index 26eb0d8d6..9d7874c64 100644
--- a/arrow-ord/src/cmp.rs
+++ b/arrow-ord/src/cmp.rs
@@ -579,13 +579,13 @@ impl<'a, T: ByteViewType> ArrayOrd for &'a
GenericByteViewArray<T> {
return false;
}
- unsafe { compare_byte_view_unchecked(l.0, l.1, r.0, r.1).is_eq() }
+ unsafe { GenericByteViewArray::compare_unchecked(l.0, l.1, r.0,
r.1).is_eq() }
}
fn is_lt(l: Self::Item, r: Self::Item) -> bool {
// # Safety
// The index is within bounds as it is checked in value()
- unsafe { compare_byte_view_unchecked(l.0, l.1, r.0, r.1).is_lt() }
+ unsafe { GenericByteViewArray::compare_unchecked(l.0, l.1, r.0,
r.1).is_lt() }
}
fn len(&self) -> usize {
@@ -626,7 +626,7 @@ pub fn compare_byte_view<T: ByteViewType>(
) -> std::cmp::Ordering {
assert!(left_idx < left.len());
assert!(right_idx < right.len());
- unsafe { compare_byte_view_unchecked(left, left_idx, right, right_idx) }
+ unsafe { GenericByteViewArray::compare_unchecked(left, left_idx, right,
right_idx) }
}
/// Comparing two [`GenericByteViewArray`] at index `left_idx` and `right_idx`
@@ -656,6 +656,7 @@ pub fn compare_byte_view<T: ByteViewType>(
///
/// # Safety
/// The left/right_idx must within range of each array
+#[deprecated(note = "Use `GenericByteViewArray::compare_unchecked` instead")]
pub unsafe fn compare_byte_view_unchecked<T: ByteViewType>(
left: &GenericByteViewArray<T>,
left_idx: usize,
diff --git a/arrow/benches/aggregate_kernels.rs
b/arrow/benches/aggregate_kernels.rs
index 9bb866f36..434bb4778 100644
--- a/arrow/benches/aggregate_kernels.rs
+++ b/arrow/benches/aggregate_kernels.rs
@@ -57,8 +57,8 @@ fn add_benchmark(c: &mut Criterion) {
primitive_benchmark::<Int64Type>(c, "int64");
{
- let nonnull_strings = create_string_array::<i32>(BATCH_SIZE, 0.0);
- let nullable_strings = create_string_array::<i32>(BATCH_SIZE, 0.5);
+ let nonnull_strings = create_string_array_with_len::<i32>(BATCH_SIZE,
0.0, 16);
+ let nullable_strings = create_string_array_with_len::<i32>(BATCH_SIZE,
0.5, 16);
c.benchmark_group("string")
.throughput(Throughput::Elements(BATCH_SIZE as u64))
.bench_function("min nonnull", |b| b.iter(||
min_string(&nonnull_strings)))
@@ -67,6 +67,25 @@ fn add_benchmark(c: &mut Criterion) {
.bench_function("max nullable", |b| b.iter(||
max_string(&nullable_strings)));
}
+ {
+ let nonnull_strings = create_string_view_array_with_len(BATCH_SIZE,
0.0, 16, false);
+ let nullable_strings = create_string_view_array_with_len(BATCH_SIZE,
0.5, 16, false);
+ c.benchmark_group("string view")
+ .throughput(Throughput::Elements(BATCH_SIZE as u64))
+ .bench_function("min nonnull", |b| {
+ b.iter(|| min_string_view(&nonnull_strings))
+ })
+ .bench_function("max nonnull", |b| {
+ b.iter(|| max_string_view(&nonnull_strings))
+ })
+ .bench_function("min nullable", |b| {
+ b.iter(|| min_string_view(&nullable_strings))
+ })
+ .bench_function("max nullable", |b| {
+ b.iter(|| max_string_view(&nullable_strings))
+ });
+ }
+
{
let nonnull_bools_mixed = create_boolean_array(BATCH_SIZE, 0.0, 0.5);
let nonnull_bools_all_false = create_boolean_array(BATCH_SIZE, 0.0,
0.0);